maltehb committed on
Commit
1b48955
1 Parent(s): 3d6951e

removing all files

LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2021 Dan Saattrup Nielsen
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
README.md DELETED
@@ -1,21 +0,0 @@
1
- ---
2
- language: da
3
- license: cc-by-4.0
4
- tags:
5
- - danish
6
- - roberta
7
- pipeline_tag: fill-mask
8
- widget:
9
- - text: "På biblioteket kan du låne en <mask>."
10
- ---
11
-
12
-
13
- # Danish Roberta Base - MC4
14
-
15
- ## Description
16
-
17
- This is a sample reference model for Flax/JAX training, trained only on the mC4 dataset. It was trained for roughly three days on a TPU v3-8. Training procedure...
18
-
19
- ---
20
- ## Description
21
- My description
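The metadata above declares a fill-mask widget with a Danish prompt. A minimal usage sketch for that pipeline is shown below; the model identifier is a placeholder, since the final Hub id is not stated in this README.

```python
# Minimal fill-mask sketch; the model id below is a placeholder, not the actual repo name.
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="<user>/roberta-base-danish")
for prediction in fill_mask("På biblioteket kan du låne en <mask>."):
    print(prediction["token_str"], prediction["score"])
```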
config.json DELETED
@@ -1,27 +0,0 @@
1
- {
2
- "_name_or_path": "./",
3
- "architectures": [
4
- "RobertaForMaskedLM"
5
- ],
6
- "attention_probs_dropout_prob": 0.1,
7
- "bos_token_id": 0,
8
- "eos_token_id": 2,
9
- "gradient_checkpointing": false,
10
- "hidden_act": "gelu",
11
- "hidden_dropout_prob": 0.1,
12
- "hidden_size": 768,
13
- "initializer_range": 0.02,
14
- "intermediate_size": 3072,
15
- "layer_norm_eps": 1e-05,
16
- "max_position_embeddings": 514,
17
- "model_type": "roberta",
18
- "num_attention_heads": 12,
19
- "num_hidden_layers": 12,
20
- "pad_token_id": 1,
21
- "position_embedding_type": "absolute",
22
- "torch_dtype": "float32",
23
- "transformers_version": "4.9.0.dev0",
24
- "type_vocab_size": 1,
25
- "use_cache": true,
26
- "vocab_size": 50265
27
- }
continue_run_mlm_flax_stream.sh DELETED
@@ -1,26 +0,0 @@
1
- export MODEL_DIR=/home/Z6HJB/roberta-base-danish/roberta-base-danish/
2
-
3
- source /home/Z6HJB/test/bin/activate
4
-
5
- python3 ./src/run_mlm_flax_stream.py \
6
- --model_name_or_path="${MODEL_DIR}" \
7
- --output_dir="${MODEL_DIR}" \
8
- --tokenizer_name="${MODEL_DIR}" \
9
- --dataset_name="mc4" \
10
- --dataset_config_name="unshuffled_deduplicated_en" \
11
- --max_seq_length="514" \
12
- --per_device_train_batch_size="32" \
13
- --per_device_eval_batch_size="32" \
14
- --learning_rate="3e-4" \
15
- --warmup_steps="1000" \
16
- --overwrite_output_dir \
17
- --adam_beta1="0.9" \
18
- --adam_beta2="0.98" \
19
- --num_train_steps="100000" \
20
- --num_eval_samples="5000" \
21
- --save_steps="1000" \
22
- --logging_steps="250" \
23
- --eval_steps="1000" \
24
- #--push_to_hub \
25
- #--config_name="${MODEL_DIR}" \
26
- #--model_type="roberta" \
events.out.tfevents.1625976092.t1v-n-ba9840ed-w-0.105113.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:24ed25853743a37a86892767785e9a9a30f82ef00ddb5250a1c3f7aa00fe9e0b
3
- size 4412262
 
 
 
 
events.out.tfevents.1625999454.t1v-n-ba9840ed-w-0.129483.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c57d5be8414b153a10ee9b8ca9d21bb2232fb9129ac5d521e5f5c94af5f9ae1f
3
- size 10389142
 
 
 
 
flax_model.msgpack DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c221eb636b60017750c5a40abee0d1578ad3425f7b8460a9d80ee59e0ae13b8b
3
- size 498796983
 
 
 
 
makefile DELETED
@@ -1,18 +0,0 @@
1
- train:
2
- python3 ./src/run_mlm_flax.py \
3
- --output_dir="." \
4
- --model_type="roberta" \
5
- --config_name="." \
6
- --tokenizer_name="." \
7
- --max_seq_length="128" \
8
- --weight_decay="0.01" \
9
- --per_device_train_batch_size="128" \
10
- --per_device_eval_batch_size="128" \
11
- --learning_rate="3e-4" \
12
- --warmup_steps="1000" \
13
- --overwrite_output_dir \
14
- --pad_to_max_length \
15
- --num_train_epochs="18" \
16
- --adam_beta1="0.9" \
17
- --adam_beta2="0.98" \
18
- --push_to_hub
md_logs/train_tokenizer.md DELETED
@@ -1,70 +0,0 @@
1
- # Setting up a Google Cloud TPU VM for training a tokenizer
2
-
3
- ## TPU VM Configurations
4
- To start off, follow the guide from the Flax/JAX community week 2021 [here](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-setup-tpu-vm), but **NOTE**: change all `pip` commands to `pip3`.
5
-
6
- Some might encounter this error message:
7
- ```
8
- Building wheel for jax (setup.py) ... error
9
- ERROR: Command errored out with exit status 1:
10
- command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; __file__='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
11
- cwd: /tmp/pip-install-lwseckn1/jax/
12
- Complete output (6 lines):
13
- usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
14
- or: setup.py --help [cmd1 cmd2 ...]
15
- or: setup.py --help-commands
16
- or: setup.py cmd --help
17
-
18
- error: invalid command 'bdist_wheel'
19
- ----------------------------------------
20
- ERROR: Failed building wheel for jax
21
- ```
22
-
23
- If you encounter this error message, run the following commands:
24
- ```
25
- pip3 install --upgrade clu
26
- pip3 install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
27
- ```
28
-
29
- Then give your user write access to the TPU log files:
30
- ```bash
31
- chmod a+rwx /tmp/*
32
- chmod a+rwx /tmp/tpu_logs/* # Just to be sure ;-)
33
- ```
34
-
35
- Afterwards, you can verify the installation either by running the following script:
36
-
37
- ```python
38
- from transformers import FlaxRobertaModel, RobertaTokenizerFast
39
- from datasets import load_dataset
40
- import jax
41
-
42
- dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
43
-
44
- dummy_input = next(iter(dataset))["text"]
45
-
46
- tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
47
- input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]
48
-
49
- model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
50
-
51
- # run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
52
- model(input_ids)
53
- ```
54
-
55
- Or by running the following `python` commands:
56
- ```python
57
- import jax
58
- jax.devices()
59
- ```
60
-
61
- ## Training the tokenizer
62
- To train the tokenizer, run the `train_tokenizer.py` script:
63
- ```bash
64
- python3 train_tokenizer.py
65
- ```
66
-
67
- ### Problems while developing the script:
68
- - Loading the '*mc4*' dataset with `load_dataset()` from Hugging Face's `datasets` package could not load multiple languages in a single call, contrary to what is described [here](https://huggingface.co/datasets/mc4). We therefore chose to load each language separately and combine them (see the sketch below).
69
- - Furthermore, it seems that even if you predefine a subset split via the `split` argument, the entire dataset still needs to be downloaded.
70
- - A bug occurred when downloading the Danish dataset, and we had to force a re-download to work around it and make the VM fetch the data.
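The per-language workaround mentioned in the first point above is sketched below, following the streaming approach used elsewhere in this repository; the language list and interleaving probabilities are illustrative assumptions, not the exact values used.

```python
# Sketch of the workaround: load each mC4 language separately as a streamed split
# and interleave the streams (languages and probabilities are illustrative only).
from datasets import load_dataset, interleave_datasets

languages = ["da", "no", "sv"]
streams = [load_dataset("mc4", lang, split="train", streaming=True) for lang in languages]
dataset = interleave_datasets(streams, probabilities=[0.34, 0.33, 0.33])

print(next(iter(dataset))["text"][:100])  # peek at the first streamed example
```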
merges.txt DELETED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:aedd97cdaaa3ac8422b20c4a391097d2af86308bd130045bc630d0ef222b183d
3
- size 498858859
 
 
 
 
run_mlm_flax_stream.sh DELETED
@@ -1,26 +0,0 @@
1
- export MODEL_DIR=/home/Z6HJB/roberta-base-danish/roberta-base-danish/
2
-
3
- source /home/Z6HJB/test/bin/activate
4
-
5
- python3 ./src/run_mlm_flax_stream.py \
6
- --config_name="${MODEL_DIR}" \
7
- --output_dir="${MODEL_DIR}" \
8
- --tokenizer_name="${MODEL_DIR}" \
9
- --model_type="roberta" \
10
- --dataset_name="mc4" \
11
- --dataset_config_name="unshuffled_deduplicated_en" \
12
- --max_seq_length="128" \
13
- --per_device_train_batch_size="128" \
14
- --per_device_eval_batch_size="128" \
15
- --learning_rate="3e-4" \
16
- --warmup_steps="1000" \
17
- --overwrite_output_dir \
18
- --adam_beta1="0.9" \
19
- --adam_beta2="0.98" \
20
- --num_train_steps="300000" \
21
- --num_eval_samples="5000" \
22
- --save_steps="1000" \
23
- --logging_steps="250" \
24
- --eval_steps="1000" \
25
- #--push_to_hub \ currently results in this error: ValueError: If not specifying `clone_from`, you need to pass Repository a valid git clone.
26
- #--model_name_or_path="${MODEL_DIR}" \ used to continue pretraining
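The `ValueError` noted next to the commented-out `--push_to_hub` flag above indicates that the output directory was not a git clone of a Hub repository. One possible fix is sketched below, assuming the target repo already exists on the Hub; the repo id and local path are placeholders.

```python
# Possible workaround (sketch): clone the existing Hub repo into the training output
# directory first, so that save_pretrained(..., push_to_hub=True) has a valid git
# clone to push from. The repo id and local directory below are placeholders.
from huggingface_hub import Repository

repo = Repository(
    local_dir="/path/to/output_dir",          # should match --output_dir
    clone_from="<user>/roberta-base-danish",  # placeholder Hub repo id
)
```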
special_tokens_map.json DELETED
@@ -1 +0,0 @@
1
- {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
 
 
src/config.py DELETED
@@ -1,9 +0,0 @@
1
- '''Create the configuration for the model'''
2
-
3
- from transformers import RobertaConfig
4
- from .utils import model_dir
5
-
6
- # Currently it merely copies the `roberta-base` config, but we can change this
7
- # of course
8
- config = RobertaConfig.from_pretrained("roberta-base")
9
- config.save_pretrained(model_dir)
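Since the comment above notes that the config is currently a plain copy of `roberta-base`, a small illustrative sketch of how it could be adjusted before saving is given below; the tokenizer path and the idea of syncing `vocab_size` are assumptions, not what this file actually does.

```python
# Illustrative only: align the copied roberta-base config with a locally trained
# tokenizer before saving, instead of keeping the plain copy. Paths are placeholders.
from transformers import RobertaConfig, RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("./")   # assumes the trained tokenizer lives here
config = RobertaConfig.from_pretrained("roberta-base")
config.vocab_size = tokenizer.vocab_size                 # keep config and tokenizer in sync
config.save_pretrained("./")
```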
src/danish_run_mlm_flax_stream.py DELETED
@@ -1,635 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2021 The HuggingFace Team All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
- text file or a dataset.
19
- Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
- https://huggingface.co/models?filter=masked-lm
21
- """
22
- import logging
23
- import os
24
- import sys
25
- import time
26
- from collections import defaultdict
27
- from dataclasses import dataclass, field
28
-
29
- # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
30
- from pathlib import Path
31
- from typing import Dict, List, Optional, Tuple
32
-
33
- import datasets
34
- import numpy as np
35
- from datasets import load_dataset, interleave_datasets
36
- from tqdm import tqdm
37
-
38
- import flax
39
- import jax
40
- import jax.numpy as jnp
41
- import optax
42
- from flax import jax_utils, traverse_util
43
- from flax.training import train_state
44
- from flax.training.common_utils import get_metrics, onehot, shard
45
- from transformers import (
46
- CONFIG_MAPPING,
47
- FLAX_MODEL_FOR_MASKED_LM_MAPPING,
48
- AutoConfig,
49
- AutoTokenizer,
50
- FlaxAutoModelForMaskedLM,
51
- HfArgumentParser,
52
- PreTrainedTokenizerBase,
53
- TensorType,
54
- TrainingArguments,
55
- is_tensorboard_available,
56
- set_seed,
57
- )
58
-
59
-
60
- if datasets.__version__ <= "1.8.0":
61
- raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
62
-
63
-
64
- MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
- MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
-
67
-
68
- @dataclass
69
- class ModelArguments:
70
- """
71
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
72
- """
73
-
74
- model_name_or_path: Optional[str] = field(
75
- default=None,
76
- metadata={
77
- "help": "The model checkpoint for weights initialization."
78
- "Don't set if you want to train a model from scratch."
79
- },
80
- )
81
- model_type: Optional[str] = field(
82
- default=None,
83
- metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
84
- )
85
- config_name: Optional[str] = field(
86
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
87
- )
88
- tokenizer_name: Optional[str] = field(
89
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
90
- )
91
- cache_dir: Optional[str] = field(
92
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
93
- )
94
- use_fast_tokenizer: bool = field(
95
- default=True,
96
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
97
- )
98
- dtype: Optional[str] = field(
99
- default="float32",
100
- metadata={
101
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
102
- },
103
- )
104
-
105
-
106
- @dataclass
107
- class DataTrainingArguments:
108
- """
109
- Arguments pertaining to what data we are going to input our model for training and eval.
110
- """
111
-
112
- dataset_name: Optional[str] = field(
113
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
114
- )
115
- dataset_config_name: Optional[str] = field(
116
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
117
- )
118
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
119
- validation_file: Optional[str] = field(
120
- default=None,
121
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
122
- )
123
- train_ref_file: Optional[str] = field(
124
- default=None,
125
- metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
126
- )
127
- validation_ref_file: Optional[str] = field(
128
- default=None,
129
- metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
130
- )
131
- overwrite_cache: bool = field(
132
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
133
- )
134
- validation_split_percentage: Optional[int] = field(
135
- default=5,
136
- metadata={
137
- "help": "The percentage of the train set used as validation set in case there's no validation split"
138
- },
139
- )
140
- max_seq_length: Optional[int] = field(
141
- default=None,
142
- metadata={
143
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
144
- "than this will be truncated. Default to the max input length of the model."
145
- },
146
- )
147
- preprocessing_num_workers: Optional[int] = field(
148
- default=None,
149
- metadata={"help": "The number of processes to use for the preprocessing."},
150
- )
151
- mlm_probability: float = field(
152
- default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
153
- )
154
- pad_to_max_length: bool = field(
155
- default=False,
156
- metadata={
157
- "help": "Whether to pad all samples to `max_seq_length`. "
158
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
159
- },
160
- )
161
- line_by_line: bool = field(
162
- default=False,
163
- metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
164
- )
165
- text_column_name: str = field(
166
- default="text", metadata={"help": "The name of the column to retrieve the training text."}
167
- )
168
- shuffle_buffer_size: int = field(
169
- default=10000, metadata={"help": "The number of examples to pre-load for shuffling."}
170
- )
171
- num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
172
- num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
173
-
174
- def __post_init__(self):
175
- if self.dataset_name is None and self.train_file is None and self.validation_file is None:
176
- raise ValueError("Need either a dataset name or a training/validation file.")
177
- else:
178
- if self.train_file is not None:
179
- extension = self.train_file.split(".")[-1]
180
- assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
181
- if self.validation_file is not None:
182
- extension = self.validation_file.split(".")[-1]
183
- assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
184
-
185
-
186
- @flax.struct.dataclass
187
- class FlaxDataCollatorForLanguageModeling:
188
- """
189
- Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
190
- are not all of the same length.
191
- Args:
192
- tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
193
- The tokenizer used for encoding the data.
194
- mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
195
- The probability with which to (randomly) mask tokens in the input.
196
- .. note::
197
- For best performance, this data collator should be used with a dataset having items that are dictionaries or
198
- BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
199
- :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
200
- argument :obj:`return_special_tokens_mask=True`.
201
- """
202
-
203
- tokenizer: PreTrainedTokenizerBase
204
- mlm_probability: float = 0.15
205
-
206
- def __post_init__(self):
207
- if self.tokenizer.mask_token is None:
208
- raise ValueError(
209
- "This tokenizer does not have a mask token which is necessary for masked language modeling. "
210
- "You should pass `mlm=False` to train on causal language modeling instead."
211
- )
212
-
213
- def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
214
- # Handle dict or lists with proper padding and conversion to tensor.
215
- batch = self.tokenizer.pad(examples, return_tensors=TensorType.NUMPY)
216
-
217
- # If special token mask has been preprocessed, pop it from the dict.
218
- special_tokens_mask = batch.pop("special_tokens_mask", None)
219
-
220
- batch["input_ids"], batch["labels"] = self.mask_tokens(
221
- batch["input_ids"], special_tokens_mask=special_tokens_mask
222
- )
223
- return batch
224
-
225
- def mask_tokens(
226
- self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
227
- ) -> Tuple[jnp.ndarray, jnp.ndarray]:
228
- """
229
- Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
230
- """
231
- labels = inputs.copy()
232
- # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
233
- probability_matrix = np.full(labels.shape, self.mlm_probability)
234
- special_tokens_mask = special_tokens_mask.astype("bool")
235
-
236
- probability_matrix[special_tokens_mask] = 0.0
237
- masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
238
- labels[~masked_indices] = -100 # We only compute loss on masked tokens
239
-
240
- # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
241
- indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
242
- inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
243
-
244
- # 10% of the time, we replace masked input tokens with random word
245
- indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
246
- indices_random &= masked_indices & ~indices_replaced
247
-
248
- random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
249
- inputs[indices_random] = random_words[indices_random]
250
-
251
- # The rest of the time (10% of the time) we keep the masked input tokens unchanged
252
- return inputs, labels
253
-
254
-
255
- def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
256
- num_samples = len(samples_idx)
257
- samples_to_remove = num_samples % batch_size
258
-
259
- if samples_to_remove != 0:
260
- samples_idx = samples_idx[:-samples_to_remove]
261
- sections_split = num_samples // batch_size
262
- batch_idx = np.split(samples_idx, sections_split)
263
- return batch_idx
264
-
265
-
266
- def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
267
- """
268
- The training iterator is advanced so that after groupifying the samples,
269
- `num_samples` of length `max_seq_length` are returned.
270
- """
271
- num_total_tokens = max_seq_length * num_samples
272
- samples = defaultdict(list)
273
-
274
- i = 0
275
- while i < num_total_tokens:
276
- tokenized_samples = next(train_iterator)
277
- i += len(tokenized_samples["input_ids"])
278
-
279
- # concatenate tokenized samples to list
280
- samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
281
-
282
- # Concatenated tokens are split to lists of length `max_seq_length`.
283
- # Note that remainedr of % max_seq_length are thrown away.
284
- def group_texts(examples):
285
- result = {
286
- k: [t[i : i + max_seq_length] for i in range(0, num_total_tokens, max_seq_length)]
287
- for k, t in examples.items()
288
- }
289
- return result
290
-
291
- grouped_samples = group_texts(samples)
292
- return grouped_samples
293
-
294
-
295
- def write_train_metric(summary_writer, train_metrics, train_time, step):
296
- summary_writer.scalar("train_time", train_time, step)
297
-
298
- train_metrics = get_metrics(train_metrics)
299
- for key, vals in train_metrics.items():
300
- tag = f"train_{key}"
301
- for i, val in enumerate(vals):
302
- summary_writer.scalar(tag, val, step - len(vals) + i + 1)
303
-
304
-
305
- def write_eval_metric(summary_writer, eval_metrics, step):
306
- for metric_name, value in eval_metrics.items():
307
- summary_writer.scalar(f"eval_{metric_name}", value, step)
308
-
309
-
310
- if __name__ == "__main__":
311
- # See all possible arguments in src/transformers/training_args.py
312
- # or by passing the --help flag to this script.
313
- # We now keep distinct sets of args, for a cleaner separation of concerns.
314
-
315
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
316
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
317
- # If we pass only one argument to the script and it's the path to a json file,
318
- # let's parse it to get our arguments.
319
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
320
- else:
321
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
322
-
323
- if (
324
- os.path.exists(training_args.output_dir)
325
- and os.listdir(training_args.output_dir)
326
- and training_args.do_train
327
- and not training_args.overwrite_output_dir
328
- ):
329
- raise ValueError(
330
- f"Output directory ({training_args.output_dir}) already exists and is not empty."
331
- "Use --overwrite_output_dir to overcome."
332
- )
333
-
334
- # Setup logging
335
- logging.basicConfig(
336
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
337
- level="INFO",
338
- datefmt="[%X]",
339
- )
340
-
341
- # Log on each process the small summary:
342
- logger = logging.getLogger(__name__)
343
- logger.warning(
344
- f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
345
- + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
346
- )
347
-
348
- # Set the verbosity to info of the Transformers logger (on main process only):
349
- logger.info(f"Training/evaluation parameters {training_args}")
350
-
351
- # Set seed before initializing model.
352
- set_seed(training_args.seed)
353
-
354
- # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
355
- # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
356
- # (the dataset will be downloaded automatically from the datasets Hub).
357
- #
358
- # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
359
- # 'text' is found. You can easily tweak this behavior (see below).
360
- if data_args.dataset_name is not None:
361
- # Downloading and loading a dataset from the hub.
362
- # dataset = load_dataset(
363
- # data_args.dataset_name,
364
- # data_args.dataset_config_name,
365
- # cache_dir=model_args.cache_dir,
366
- # streaming=True,
367
- # split="train",
368
- # )
369
-
370
- dataset = load_dataset("mc4", "da", split="train", streaming=True)
371
- # norwegian_dataset = load_dataset("mc4", "no", split="train", streaming=True)
372
- # swedish_dataset = load_dataset("mc4", "sv", split="train", streaming=True)
373
-
374
- # danish_dataset_subset = danish_dataset.take(int(24.1e6))
375
- # norwegian_dataset_subset = norwegian_dataset.take(int(24.1e6))
376
- # swedish_dataset_subset = swedish_dataset.take(int(24.1e6))
377
-
378
- # dataset = interleave_datasets(
379
- # [danish_dataset_subset, norwegian_dataset_subset, swedish_dataset_subset], probabilities=[0.34, 0.33, 0.33]
380
- # )
381
-
382
- if model_args.config_name:
383
- config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
384
- elif model_args.model_name_or_path:
385
- print(f"Setting config from path: {model_args.model_name_or_path}")
386
- config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
387
- else:
388
- config = CONFIG_MAPPING[model_args.model_type]()
389
- logger.warning("You are instantiating a new config instance from scratch.")
390
-
391
- if model_args.tokenizer_name:
392
- tokenizer = AutoTokenizer.from_pretrained(
393
- model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
394
- )
395
- elif model_args.model_name_or_path:
396
- tokenizer = AutoTokenizer.from_pretrained(
397
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
398
- )
399
- else:
400
- raise ValueError(
401
- "You are instantiating a new tokenizer from scratch. This is not supported by this script."
402
- "You can do it from another script, save it, and load it from here, using --tokenizer_name."
403
- )
404
-
405
- # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
406
- # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
407
- # efficient when it receives the `special_tokens_mask`.
408
- def tokenize_function(examples):
409
- return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
410
-
411
- tokenized_datasets = dataset.map(
412
- tokenize_function,
413
- batched=True,
414
- )
415
-
416
- shuffle_seed = training_args.seed
417
- tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
418
-
419
- has_tensorboard = is_tensorboard_available()
420
- if has_tensorboard and jax.process_index() == 0:
421
- try:
422
- from flax.metrics.tensorboard import SummaryWriter
423
- except ImportError as ie:
424
- has_tensorboard = False
425
- logger.warning(
426
- f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
427
- )
428
-
429
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
430
-
431
- # Data collator
432
- # This one will take care of randomly masking the tokens.
433
- data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
434
-
435
- # Initialize our training
436
- rng = jax.random.PRNGKey(training_args.seed)
437
- dropout_rngs = jax.random.split(rng, jax.local_device_count())
438
-
439
- if model_args.model_name_or_path:
440
- model = FlaxAutoModelForMaskedLM.from_pretrained(
441
- model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
442
- )
443
- else:
444
- model = FlaxAutoModelForMaskedLM.from_config(
445
- config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
446
- )
447
-
448
- # Store some constant
449
- num_epochs = int(training_args.num_train_epochs)
450
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
451
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
452
-
453
- # define number steps per stream epoch
454
- num_train_steps = data_args.num_train_steps
455
-
456
- # Create learning rate schedule
457
- warmup_fn = optax.linear_schedule(
458
- init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
459
- )
460
- decay_fn = optax.linear_schedule(
461
- init_value=training_args.learning_rate,
462
- end_value=0,
463
- transition_steps=num_train_steps - training_args.warmup_steps,
464
- )
465
- linear_decay_lr_schedule_fn = optax.join_schedules(
466
- schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
467
- )
468
-
469
- # We use Optax's "masking" functionality to not apply weight decay
470
- # to bias and LayerNorm scale parameters. decay_mask_fn returns a
471
- # mask boolean with the same structure as the parameters.
472
- # The mask is True for parameters that should be decayed.
473
- # Note that this mask is specifically adapted for FlaxBERT-like models.
474
- # For other models, one should correct the layer norm parameter naming
475
- # accordingly.
476
- def decay_mask_fn(params):
477
- flat_params = traverse_util.flatten_dict(params)
478
- flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
479
- return traverse_util.unflatten_dict(flat_mask)
480
-
481
- # create adam optimizer
482
- adamw = optax.adamw(
483
- learning_rate=linear_decay_lr_schedule_fn,
484
- b1=training_args.adam_beta1,
485
- b2=training_args.adam_beta2,
486
- eps=training_args.adam_epsilon,
487
- weight_decay=training_args.weight_decay,
488
- mask=decay_mask_fn,
489
- )
490
-
491
- # Setup train state
492
- state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
493
-
494
- # Define gradient update step fn
495
- def train_step(state, batch, dropout_rng):
496
- dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
497
-
498
- def loss_fn(params):
499
- labels = batch.pop("labels")
500
-
501
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
502
-
503
- # compute loss, ignore padded input tokens
504
- label_mask = jnp.where(labels > 0, 1.0, 0.0)
505
- loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
506
-
507
- # take average
508
- loss = loss.sum() / label_mask.sum()
509
-
510
- return loss
511
-
512
- grad_fn = jax.value_and_grad(loss_fn)
513
- loss, grad = grad_fn(state.params)
514
- grad = jax.lax.pmean(grad, "batch")
515
- new_state = state.apply_gradients(grads=grad)
516
-
517
- metrics = jax.lax.pmean(
518
- {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
519
- )
520
-
521
- return new_state, metrics, new_dropout_rng
522
-
523
- # Create parallel version of the train step
524
- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
525
-
526
- # Define eval fn
527
- def eval_step(params, batch):
528
- labels = batch.pop("labels")
529
-
530
- logits = model(**batch, params=params, train=False)[0]
531
-
532
- # compute loss, ignore padded input tokens
533
- label_mask = jnp.where(labels > 0, 1.0, 0.0)
534
- loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
535
-
536
- # compute accuracy
537
- accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
538
-
539
- # summarize metrics
540
- metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
541
- metrics = jax.lax.psum(metrics, axis_name="batch")
542
-
543
- return metrics
544
-
545
- p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
546
-
547
- # Replicate the train state on each device
548
- state = jax_utils.replicate(state)
549
-
550
- train_time = 0
551
- train_start = time.time()
552
- train_metrics = []
553
- eval_metrics = []
554
-
555
- training_iter = iter(tokenized_datasets)
556
-
557
- max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
558
- eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
559
-
560
- steps = tqdm(range(num_train_steps), desc="Training...", position=0)
561
- for step in range(num_train_steps):
562
- # ======================== Training ================================
563
- try:
564
- samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
565
- except StopIteration:
566
- # Once the end of the dataset stream is reached, the training iterator
567
- # is reinitialized and reshuffled and a new eval dataset is randomely chosen.
568
- shuffle_seed += 1
569
- tokenized_datasets.set_epoch(shuffle_seed)
570
-
571
- training_iter = iter(tokenized_datasets)
572
-
573
- eval_dataset = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
574
- samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
575
-
576
- # process input samples
577
- model_inputs = data_collator(samples)
578
-
579
- # Model forward
580
- model_inputs = shard(model_inputs.data)
581
- state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
582
-
583
- train_metrics.append(train_metric)
584
-
585
- if step % training_args.logging_steps == 0 and step > 0:
586
- steps.write(
587
- f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
588
- )
589
- train_time += time.time() - train_start
590
- if has_tensorboard and jax.process_index() == 0:
591
- write_train_metric(summary_writer, train_metrics, train_time, step)
592
- train_metrics = []
593
-
594
- # ======================== Evaluating ==============================
595
- if step % training_args.eval_steps == 0 and step > 0:
596
- eval_samples_idx = jnp.arange(data_args.num_eval_samples)
597
- eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
598
-
599
- for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
600
- # process input samples
601
- batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
602
- model_inputs = data_collator(batch_eval_samples)
603
-
604
- # Model forward
605
- model_inputs = shard(model_inputs.data)
606
- metrics = p_eval_step(state.params, model_inputs)
607
- eval_metrics.append(metrics)
608
-
609
- # normalize eval metrics
610
- eval_metrics = get_metrics(eval_metrics)
611
- eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
612
- eval_normalizer = eval_metrics.pop("normalizer")
613
- eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
614
-
615
- # Update progress bar
616
- steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
617
-
618
- if has_tensorboard and jax.process_index() == 0:
619
- write_eval_metric(summary_writer, eval_metrics, step)
620
- eval_metrics = []
621
-
622
- # Saving at each save_step
623
- if step % training_args.save_steps == 0 and step > 0:
624
- # save checkpoint after each epoch and push checkpoint to the hub
625
- if jax.process_index() == 0:
626
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
627
- model.save_pretrained(
628
- training_args.output_dir,
629
- params=params,
630
- push_to_hub=training_args.push_to_hub,
631
- commit_message=f"Saving weights and logs of step {step+1}",
632
- )
633
-
634
- # update tqdm bar
635
- steps.update(1)
src/gigaword.py DELETED
@@ -1 +0,0 @@
1
- '''Functions and classes for loading/streaming the relevant GigaWord datasets'''
 
 
src/mc4.py DELETED
@@ -1 +0,0 @@
1
- '''Functions and classes for loading/streaming the relevant mC4 datasets'''
 
 
src/scandi_run_mlm_flax.py DELETED
@@ -1,681 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2021 The HuggingFace Team All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
- text file or a dataset.
19
-
20
- Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
- https://huggingface.co/models?filter=masked-lm
22
- """
23
- import logging
24
- import os
25
- import sys
26
- import time
27
- from dataclasses import dataclass, field
28
-
29
- # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
30
- from pathlib import Path
31
- from typing import Dict, List, Optional, Tuple
32
-
33
- import numpy as np
34
- from datasets import load_dataset, concatenate_datasets, interleave_datasets
35
- from tqdm import tqdm
36
-
37
- import flax
38
- import jax
39
- import jax.numpy as jnp
40
- import optax
41
- from flax import jax_utils, traverse_util
42
- from flax.training import train_state
43
- from flax.training.common_utils import get_metrics, onehot, shard
44
- from transformers import (
45
- CONFIG_MAPPING,
46
- FLAX_MODEL_FOR_MASKED_LM_MAPPING,
47
- AutoConfig,
48
- AutoTokenizer,
49
- FlaxAutoModelForMaskedLM,
50
- HfArgumentParser,
51
- PreTrainedTokenizerBase,
52
- TensorType,
53
- TrainingArguments,
54
- is_tensorboard_available,
55
- set_seed,
56
- )
57
-
58
-
59
- # Cache the result
60
- has_tensorboard = is_tensorboard_available()
61
- if has_tensorboard:
62
- try:
63
- from flax.metrics.tensorboard import SummaryWriter
64
- except ImportError as ie:
65
- has_tensorboard = False
66
- print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
67
-
68
- else:
69
- print(
70
- "Unable to display metrics through TensorBoard because the package is not installed: "
71
- "Please run pip install tensorboard to enable."
72
- )
73
-
74
-
75
- MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
76
- MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
77
-
78
-
79
- @dataclass
80
- class ModelArguments:
81
- """
82
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
83
- """
84
-
85
- model_name_or_path: Optional[str] = field(
86
- default=None,
87
- metadata={
88
- "help": "The model checkpoint for weights initialization."
89
- "Don't set if you want to train a model from scratch."
90
- },
91
- )
92
- model_type: Optional[str] = field(
93
- default=None,
94
- metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
95
- )
96
- config_name: Optional[str] = field(
97
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
98
- )
99
- tokenizer_name: Optional[str] = field(
100
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
101
- )
102
- cache_dir: Optional[str] = field(
103
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
104
- )
105
- use_fast_tokenizer: bool = field(
106
- default=True,
107
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
108
- )
109
- dtype: Optional[str] = field(
110
- default="float32",
111
- metadata={
112
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
113
- },
114
- )
115
-
116
-
117
- @dataclass
118
- class DataTrainingArguments:
119
- """
120
- Arguments pertaining to what data we are going to input our model for training and eval.
121
- """
122
-
123
- dataset_name: Optional[str] = field(
124
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
125
- )
126
- dataset_config_name: Optional[str] = field(
127
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
128
- )
129
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
130
- validation_file: Optional[str] = field(
131
- default=None,
132
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
133
- )
134
- train_ref_file: Optional[str] = field(
135
- default=None,
136
- metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
137
- )
138
- validation_ref_file: Optional[str] = field(
139
- default=None,
140
- metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
141
- )
142
- overwrite_cache: bool = field(
143
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
144
- )
145
- validation_split_percentage: Optional[int] = field(
146
- default=5,
147
- metadata={
148
- "help": "The percentage of the train set used as validation set in case there's no validation split"
149
- },
150
- )
151
- max_seq_length: Optional[int] = field(
152
- default=None,
153
- metadata={
154
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
155
- "than this will be truncated. Default to the max input length of the model."
156
- },
157
- )
158
- preprocessing_num_workers: Optional[int] = field(
159
- default=None,
160
- metadata={"help": "The number of processes to use for the preprocessing."},
161
- )
162
- mlm_probability: float = field(
163
- default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
164
- )
165
- pad_to_max_length: bool = field(
166
- default=False,
167
- metadata={
168
- "help": "Whether to pad all samples to `max_seq_length`. "
169
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
170
- },
171
- )
172
- line_by_line: bool = field(
173
- default=False,
174
- metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
175
- )
176
-
177
- def __post_init__(self):
178
- if self.dataset_name is None and self.train_file is None and self.validation_file is None:
179
- raise ValueError("Need either a dataset name or a training/validation file.")
180
- else:
181
- if self.train_file is not None:
182
- extension = self.train_file.split(".")[-1]
183
- assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
184
- if self.validation_file is not None:
185
- extension = self.validation_file.split(".")[-1]
186
- assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
187
-
188
-
189
- @flax.struct.dataclass
190
- class FlaxDataCollatorForLanguageModeling:
191
- """
192
- Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
193
- are not all of the same length.
194
-
195
- Args:
196
- tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
197
- The tokenizer used for encoding the data.
198
- mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
199
- The probability with which to (randomly) mask tokens in the input.
200
-
201
- .. note::
202
-
203
- For best performance, this data collator should be used with a dataset having items that are dictionaries or
204
- BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
205
- :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
206
- argument :obj:`return_special_tokens_mask=True`.
207
- """
208
-
209
- tokenizer: PreTrainedTokenizerBase
210
- mlm_probability: float = 0.15
211
-
212
- def __post_init__(self):
213
- if self.tokenizer.mask_token is None:
214
- raise ValueError(
215
- "This tokenizer does not have a mask token which is necessary for masked language modeling. "
216
- "You should pass `mlm=False` to train on causal language modeling instead."
217
- )
218
-
219
- def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
220
- # Handle dict or lists with proper padding and conversion to tensor.
221
- batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
222
-
223
- # If special token mask has been preprocessed, pop it from the dict.
224
- special_tokens_mask = batch.pop("special_tokens_mask", None)
225
-
226
- batch["input_ids"], batch["labels"] = self.mask_tokens(
227
- batch["input_ids"], special_tokens_mask=special_tokens_mask
228
- )
229
- return batch
230
-
231
- def mask_tokens(
232
- self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
233
- ) -> Tuple[jnp.ndarray, jnp.ndarray]:
234
- """
235
- Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
236
- """
237
- labels = inputs.copy()
238
- # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
239
- probability_matrix = np.full(labels.shape, self.mlm_probability)
240
- special_tokens_mask = special_tokens_mask.astype("bool")
241
-
242
- probability_matrix[special_tokens_mask] = 0.0
243
- masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
244
- labels[~masked_indices] = -100 # We only compute loss on masked tokens
245
-
246
- # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
247
- indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
248
- inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
249
-
250
- # 10% of the time, we replace masked input tokens with random word
251
- indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
252
- indices_random &= masked_indices & ~indices_replaced
253
-
254
- random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
255
- inputs[indices_random] = random_words[indices_random]
256
-
257
- # The rest of the time (10% of the time) we keep the masked input tokens unchanged
258
- return inputs, labels
259
-
260
-
261
- def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
262
- num_samples = len(samples_idx)
263
- samples_to_remove = num_samples % batch_size
264
-
265
- if samples_to_remove != 0:
266
- samples_idx = samples_idx[:-samples_to_remove]
267
- sections_split = num_samples // batch_size
268
- batch_idx = np.split(samples_idx, sections_split)
269
- return batch_idx
270
-
271
-
272
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
273
- summary_writer.scalar("train_time", train_time, step)
274
-
275
- train_metrics = get_metrics(train_metrics)
276
- for key, vals in train_metrics.items():
277
- tag = f"train_{key}"
278
- for i, val in enumerate(vals):
279
- summary_writer.scalar(tag, val, step - len(vals) + i + 1)
280
-
281
- for metric_name, value in eval_metrics.items():
282
- summary_writer.scalar(f"eval_{metric_name}", value, step)
283
-
284
-
285
- if __name__ == "__main__":
286
- # See all possible arguments in src/transformers/training_args.py
287
- # or by passing the --help flag to this script.
288
- # We now keep distinct sets of args, for a cleaner separation of concerns.
289
-
290
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
291
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
292
- # If we pass only one argument to the script and it's the path to a json file,
293
- # let's parse it to get our arguments.
294
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
295
- else:
296
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
297
-
298
- if (
299
- os.path.exists(training_args.output_dir)
300
- and os.listdir(training_args.output_dir)
301
- and training_args.do_train
302
- and not training_args.overwrite_output_dir
303
- ):
304
- raise ValueError(
305
- f"Output directory ({training_args.output_dir}) already exists and is not empty."
306
- "Use --overwrite_output_dir to overcome."
307
- )
308
-
309
- # Setup logging
310
- logging.basicConfig(
311
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
312
- level="NOTSET",
313
- datefmt="[%X]",
314
- )
315
-
316
- # Log on each process the small summary:
317
- logger = logging.getLogger(__name__)
318
- logger.warning(
319
- f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
320
- + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
321
- )
322
-
323
- # Set the verbosity to info of the Transformers logger (on main process only):
324
- logger.info(f"Training/evaluation parameters {training_args}")
325
-
326
- # Set seed before initializing model.
327
- set_seed(training_args.seed)
328
-
329
- # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
330
- # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
331
- # (the dataset will be downloaded automatically from the datasets Hub).
332
- #
333
- # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
334
- # 'text' is found. You can easily tweak this behavior (see below).
335
- #
336
- # In distributed training, the load_dataset function guarantees that only one local process can concurrently
337
- # download the dataset.
338
- # if data_args.dataset_name is not None:
339
- # Downloading and loading a dataset from the hub.
340
- # datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
341
-
342
- # # Downloading the scandinavian datasets and concatenating them
343
- # danish_dataset = load_dataset("mc4", "da", split="train[:24100000]") # , download_mode="force_redownload")
344
- # norwegian_dataset = load_dataset("mc4", "no", split="train[:24100000]") # , download_mode="force_redownload")
345
- # swedish_dataset = load_dataset("mc4", "sv", split="train[:24100000]") # , download_mode="force_redownload")
346
- # all_datasets = concatenate_datasets([danish_dataset, norwegian_dataset, swedish_dataset])
347
- # datasets = all_datasets.shuffle()
348
- # datasets = datasets.select(range(1000))
349
- # datasets = datasets.train_test_split(test_size=0.01)
350
- # datasets["validation"] = datasets["test"]
351
-
352
- # Downloading the scandinavian datasets and interleaving them
353
- danish_dataset = load_dataset('mc4', 'da', split="train[:24100000]", streaming=True)
354
- norwegian_dataset = load_dataset('mc4', 'no', split="train[:24100000]", streaming=True)
355
- swedish_dataset = load_dataset('mc4', 'sv', split="train[:24100000]", streaming=True)
356
-
357
- dataset = interleave_datasets([danish_dataset, norwegian_dataset, swedish_dataset], probabilities=[0.33, 0.33, 0.33])
358
- dataset = dataset.train_test_split(test_size=0.01)
359
- dataset["validation"] = dataset["test"]
360
-
361
-
362
- if "validation" not in datasets.keys():
363
- datasets["validation"] = load_dataset(
364
- data_args.dataset_name,
365
- data_args.dataset_config_name,
366
- split=f"train[:{data_args.validation_split_percentage}%]",
367
- cache_dir=model_args.cache_dir,
368
- )
369
- datasets["train"] = load_dataset(
370
- data_args.dataset_name,
371
- data_args.dataset_config_name,
372
- split=f"train[{data_args.validation_split_percentage}%:]",
373
- cache_dir=model_args.cache_dir,
374
- )
375
- # else:
376
- # data_files = {}
377
- # if data_args.train_file is not None:
378
- # data_files["train"] = data_args.train_file
379
- # if data_args.validation_file is not None:
380
- # data_files["validation"] = data_args.validation_file
381
- # extension = data_args.train_file.split(".")[-1]
382
- # if extension == "txt":
383
- # extension = "text"
384
-
385
- # datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
386
- # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
387
- # https://huggingface.co/docs/datasets/loading_datasets.html.
388
-
389
- # Load pretrained model and tokenizer
-
- # Distributed training:
- # The .from_pretrained methods guarantee that only one local process can concurrently
- # download model & vocab.
- if model_args.config_name:
-     config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
- elif model_args.model_name_or_path:
-     config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
- else:
-     config = CONFIG_MAPPING[model_args.model_type]()
-     logger.warning("You are instantiating a new config instance from scratch.")
-
- if model_args.tokenizer_name:
-     tokenizer = AutoTokenizer.from_pretrained(
-         model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-     )
- elif model_args.model_name_or_path:
-     tokenizer = AutoTokenizer.from_pretrained(
-         model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-     )
- else:
-     raise ValueError(
-         "You are instantiating a new tokenizer from scratch. This is not supported by this script."
-         "You can do it from another script, save it, and load it from here, using --tokenizer_name."
-     )
-
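For a from-scratch run like this one, `--config_name` and `--tokenizer_name` both point at the local model directory (as in the commented launch command at the end of this file), so the branches above effectively reduce to the following sketch. The directory layout with `config.json` and `tokenizer.json` side by side is an assumption:

    from transformers import AutoConfig, AutoTokenizer

    model_dir = "./"  # hypothetical local directory holding config.json and tokenizer.json
    config = AutoConfig.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    print(config.model_type, tokenizer.vocab_size)
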
- # Preprocessing the datasets.
- # First we tokenize all the texts.
- if training_args.do_train:
-     column_names = datasets["train"].column_names
- else:
-     column_names = datasets["validation"].column_names
- text_column_name = "text" if "text" in column_names else column_names[0]
-
- max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
-
- if data_args.line_by_line:
-     # When using line_by_line, we just tokenize each nonempty line.
-     padding = "max_length" if data_args.pad_to_max_length else False
-
-     def tokenize_function(examples):
-         # Remove empty lines
-         examples = [line for line in examples if len(line) > 0 and not line.isspace()]
-         return tokenizer(
-             examples,
-             return_special_tokens_mask=True,
-             padding=padding,
-             truncation=True,
-             max_length=max_seq_length,
-         )
-
-     tokenized_datasets = datasets.map(
-         tokenize_function,
-         input_columns=[text_column_name],
-         batched=True,
-         num_proc=data_args.preprocessing_num_workers,
-         remove_columns=column_names,
-         load_from_cache_file=not data_args.overwrite_cache,
-     )
-
- else:
-     # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
-     # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
-     # efficient when it receives the `special_tokens_mask`.
-     def tokenize_function(examples):
-         return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
-
-     tokenized_datasets = datasets.map(
-         tokenize_function,
-         batched=True,
-         num_proc=data_args.preprocessing_num_workers,
-         remove_columns=column_names,
-         load_from_cache_file=not data_args.overwrite_cache,
-     )
-
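As a quick illustration of what `tokenize_function` hands to `datasets.map`: the tokenizer returns plain Python lists per example, including the `special_tokens_mask` the collator later relies on. The Danish sentence is an arbitrary placeholder, and `tokenizer` and `max_seq_length` are the objects defined above:

    encoding = tokenizer(
        "Dette er en dansk prøvesætning.",
        return_special_tokens_mask=True,
        truncation=True,
        max_length=max_seq_length,
    )
    print(encoding["input_ids"])            # ids including <s> ... </s>
    print(encoding["special_tokens_mask"])  # 1 where a special token was inserted
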
-     # Main data processing function that will concatenate all texts from our dataset and generate chunks of
-     # max_seq_length.
-     def group_texts(examples):
-         # Concatenate all texts.
-         concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
-         total_length = len(concatenated_examples[list(examples.keys())[0]])
-         # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
-         # customize this part to your needs.
-         total_length = (total_length // max_seq_length) * max_seq_length
-         # Split by chunks of max_len.
-         result = {
-             k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
-             for k, t in concatenated_examples.items()
-         }
-         return result
-
-     # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
-     # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
-     # might be slower to preprocess.
-     #
-     # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
-     # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
-     tokenized_datasets = tokenized_datasets.map(
-         group_texts,
-         batched=True,
-         num_proc=data_args.preprocessing_num_workers,
-         load_from_cache_file=not data_args.overwrite_cache,
-     )
-
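A toy walkthrough of `group_texts` makes the chunking and the dropped remainder concrete; the made-up `max_seq_length` of 4 is for illustration only:

    toy = {"input_ids": [[0, 11, 12, 2], [0, 21, 22, 23, 24, 2]]}  # two already-tokenized examples
    # concatenation -> [0, 11, 12, 2, 0, 21, 22, 23, 24, 2]  (length 10)
    # with max_seq_length = 4 the result is [[0, 11, 12, 2], [0, 21, 22, 23]]
    # and the trailing [24, 2] is dropped, exactly as total_length is rounded down above
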
- # Enable tensorboard only on the master node
- if has_tensorboard and jax.process_index() == 0:
-     summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
-
- # Data collator
- # This one will take care of randomly masking the tokens.
- data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
-
- # Initialize our training
- rng = jax.random.PRNGKey(training_args.seed)
- dropout_rngs = jax.random.split(rng, jax.local_device_count())
-
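The collator performs the usual BERT-style corruption: a fraction `mlm_probability` (typically 0.15) of positions is selected, roughly 80% of those are replaced by `<mask>`, 10% by a random token, and 10% are left unchanged, with only the selected positions contributing to the loss. A NumPy-only sketch of that standard rule, for illustration rather than the collator's actual implementation:

    import numpy as np

    def mask_tokens_sketch(input_ids, mask_id, vocab_size, mlm_probability=0.15, seed=0):
        rng = np.random.default_rng(seed)
        input_ids = np.array(input_ids)
        labels = np.full_like(input_ids, -100)   # -100 = position ignored by the loss
        selected = rng.random(input_ids.shape) < mlm_probability
        labels[selected] = input_ids[selected]   # the model only has to predict these
        roll = rng.random(input_ids.shape)
        to_mask = selected & (roll < 0.8)                     # ~80% of selected -> <mask>
        to_random = selected & (roll >= 0.8) & (roll < 0.9)   # ~10% -> random token
        input_ids[to_mask] = mask_id
        input_ids[to_random] = rng.integers(0, vocab_size, input_ids.shape)[to_random]
        return input_ids, labels                 # the final ~10% keep their original token
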
- model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
-
- # Store some constant
- num_epochs = int(training_args.num_train_epochs)
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
-
- num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
-
- # Create learning rate schedule
- warmup_fn = optax.linear_schedule(
-     init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
- )
- decay_fn = optax.linear_schedule(
-     init_value=training_args.learning_rate,
-     end_value=0,
-     transition_steps=num_train_steps - training_args.warmup_steps,
- )
- linear_decay_lr_schedule_fn = optax.join_schedules(
-     schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
- )
-
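Evaluating the joined schedule at a few steps shows its shape directly. With a peak rate of 3e-4 and 1,000 warmup steps, matching the commented launch command at the end of this file, and an assumed 100,000 total steps, the rate climbs linearly to the peak and then decays linearly back to zero:

    import optax

    peak_lr, warmup, total = 3e-4, 1_000, 100_000   # example values only
    warmup_fn = optax.linear_schedule(0.0, peak_lr, warmup)
    decay_fn = optax.linear_schedule(peak_lr, 0.0, total - warmup)
    schedule = optax.join_schedules([warmup_fn, decay_fn], boundaries=[warmup])

    for step in (0, 500, 1_000, 50_500, 100_000):
        print(step, float(schedule(step)))   # 0.0, 1.5e-4, 3e-4, 1.5e-4, 0.0
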
- # We use Optax's "masking" functionality to not apply weight decay
- # to bias and LayerNorm scale parameters. decay_mask_fn returns a
- # mask boolean with the same structure as the parameters.
- # The mask is True for parameters that should be decayed.
- # Note that this mask is specifically adapted for FlaxBERT-like models.
- # For other models, one should correct the layer norm parameter naming
- # accordingly.
- def decay_mask_fn(params):
-     flat_params = traverse_util.flatten_dict(params)
-     flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
-     return traverse_util.unflatten_dict(flat_mask)
-
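On a toy parameter tree the effect is easy to see: kernels get weight decay, biases and LayerNorm scales do not. The nested names below are invented purely for the illustration, and `traverse_util` is the module already used by `decay_mask_fn` above:

    toy_params = {
        "encoder": {
            "dense": {"kernel": 1.0, "bias": 0.1},
            "LayerNorm": {"scale": 1.0, "bias": 0.0},
        }
    }
    print(decay_mask_fn(toy_params))
    # {'encoder': {'dense': {'kernel': True, 'bias': False},
    #              'LayerNorm': {'scale': False, 'bias': False}}}
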
- # create adam optimizer
- adamw = optax.adamw(
-     learning_rate=linear_decay_lr_schedule_fn,
-     b1=training_args.adam_beta1,
-     b2=training_args.adam_beta2,
-     eps=1e-8,
-     weight_decay=training_args.weight_decay,
-     mask=decay_mask_fn,
- )
-
- # Setup train state
- state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
-
- # Define gradient update step fn
- def train_step(state, batch, dropout_rng):
-     dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
-
-     def loss_fn(params):
-         labels = batch.pop("labels")
-
-         logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-
-         # compute loss, ignore padded input tokens
-         label_mask = jnp.where(labels > 0, 1.0, 0.0)
-         loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
-
-         # take average
-         loss = loss.sum() / label_mask.sum()
-
-         return loss
-
-     grad_fn = jax.value_and_grad(loss_fn)
-     loss, grad = grad_fn(state.params)
-     grad = jax.lax.pmean(grad, "batch")
-     new_state = state.apply_gradients(grads=grad)
-
-     metrics = jax.lax.pmean(
-         {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
-     )
-
-     return new_state, metrics, new_dropout_rng
-
- # Create parallel version of the train step
- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-
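`shard` (from `flax.training.common_utils`) only reshapes the leading batch dimension into `(local_device_count, per_device_batch, ...)` so that `pmap` can hand one slice to each device. A shape-only sketch with made-up sizes (8 devices, per-device batch 32, sequence length 128):

    import numpy as np

    n_devices, per_device_bs, seq_len = 8, 32, 128
    global_batch = np.zeros((n_devices * per_device_bs, seq_len), dtype=np.int32)
    per_device = global_batch.reshape((n_devices, per_device_bs, seq_len))
    print(per_device.shape)   # (8, 32, 128): the layout p_train_step receives
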
- # Define eval fn
- def eval_step(params, batch):
-     labels = batch.pop("labels")
-
-     logits = model(**batch, params=params, train=False)[0]
-
-     # compute loss, ignore padded input tokens
-     label_mask = jnp.where(labels > 0, 1.0, 0.0)
-     loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
-
-     # compute accuracy
-     accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
-
-     # summarize metrics
-     metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
-     metrics = jax.lax.psum(metrics, axis_name="batch")
-
-     return metrics
-
- p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
-
- # Replicate the train state on each device
- state = jax_utils.replicate(state)
-
- train_time = 0
- epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
- for epoch in epochs:
-     # ======================== Training ================================
-     train_start = time.time()
-     train_metrics = []
-
-     # Create sampling rng
-     rng, input_rng = jax.random.split(rng)
-
-     # Generate an epoch by shuffling sampling indices from the train dataset
-     num_train_samples = len(tokenized_datasets["train"])
-     train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
-     train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
-
-     # Gather the indexes for creating the batch and do a training step
-     for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
-         samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
-         model_inputs = data_collator(samples, pad_to_multiple_of=16)
-
-         # Model forward
-         model_inputs = shard(model_inputs.data)
-         state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
-         train_metrics.append(train_metric)
-
-     train_time += time.time() - train_start
-
-     epochs.write(
-         f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-     )
-
-     # ======================== Evaluating ==============================
-     num_eval_samples = len(tokenized_datasets["validation"])
-     eval_samples_idx = jnp.arange(num_eval_samples)
-     eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
-
-     eval_metrics = []
-     for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
-         samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
-         model_inputs = data_collator(samples, pad_to_multiple_of=16)
-
-         # Model forward
-         model_inputs = shard(model_inputs.data)
-         metrics = p_eval_step(state.params, model_inputs)
-         eval_metrics.append(metrics)
-
-     # normalize eval metrics
-     eval_metrics = get_metrics(eval_metrics)
-     eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
-     eval_normalizer = eval_metrics.pop("normalizer")
-     eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
-
-     # Update progress bar
-     epochs.desc = (
-         f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
-     )
-
-     # Save metrics
-     if has_tensorboard and jax.process_index() == 0:
-         cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-         write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
-
-     # save checkpoint after each epoch and push checkpoint to the hub
-     if jax.process_index() == 0:
-         params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-         model.save_pretrained(
-             training_args.output_dir,
-             params=params,
-             push_to_hub=training_args.push_to_hub,
-             commit_message=f"Saving weights and logs of epoch {epoch+1}",
-         )
-
- # python3 ./roberta-large-scandi/src/scandi_run_mlm_flax.py --output_dir="./roberta-large-scandi/runs" --model_type="roberta" --config_name="${MODEL_DIR}" --tokenizer_name="${MODEL_DIR}" --dataset_name="mc4" --max_seq_length="128" --weight_decay="0.01" --per_device_train_batch_size="128" --per_device_eval_batch_size="128" --learning_rate="3e-4" --warmup_steps="1000" --overwrite_output_dir --pad_to_max_length --num_train_epochs="10" --adam_beta1="0.9" --adam_beta2="0.98"
 
src/tokenizer.py DELETED
@@ -1,27 +0,0 @@
1
- '''Training script for tokenizer'''
2
-
3
- from datasets import load_dataset
4
- from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
5
- from .utils import model_dir
6
-
7
- # load dataset
8
- dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
9
-
10
- # Instantiate tokenizer
11
- tokenizer = ByteLevelBPETokenizer()
12
-
13
- def batch_iterator(batch_size=1000):
14
- for i in range(0, len(dataset), batch_size):
15
- yield dataset[i: i + batch_size]["text"]
16
-
17
- # Customized training
18
- tokenizer.train_from_iterator(
19
- batch_iterator(),
20
- vocab_size=50265,
21
- min_frequency=2,
22
- special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
23
- )
24
-
25
- # Save files to disk
26
- tokenizer_path = model_dir / 'tokenizer.json'
27
- tokenizer.save(str(tokenizer_path))
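Once `tokenizer.json` is on disk it can be loaded straight back with the `tokenizers` library (or wrapped in a fast transformers tokenizer). A small sketch, assuming the file was saved to the working directory:

    from tokenizers import Tokenizer

    tok = Tokenizer.from_file("tokenizer.json")
    encoding = tok.encode("Dette er en dansk prøvesætning.")
    print(encoding.tokens)
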
 
src/train_tokenizer.py DELETED
@@ -1,44 +0,0 @@
1
- from datasets import load_dataset, concatenate_datasets
2
- from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
3
-
4
- model_dir = "./scandinavian" # ${MODEL_DIR}
5
-
6
- # load dataset
7
- # dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
8
- # mc4_subset_with_five_languages = load_dataset("mc4", languages=["en", "fr", "es", "de", "zh"])
9
- # yoruba_dataset = load_dataset("mc4", "yo", split="train[0:10]")
10
- # yoruba_dataset2 = load_dataset("mc4", "yo", split="train[10:20]")
11
-
12
- danish_dataset = load_dataset("mc4", "da") # , download_mode="force_redownload")
13
- norwegian_dataset = load_dataset("mc4", "no") # , download_mode="force_redownload")
14
- swedish_dataset = load_dataset("mc4", "sv") # , download_mode="force_redownload")
15
-
16
- # all_datasets = concatenate_datasets([yoruba_dataset, yoruba_dataset2])
17
- all_datasets = concatenate_datasets([danish_dataset, norwegian_dataset, swedish_dataset])
18
- all_datasets = all_datasets.shuffle()
19
-
20
- # Instantiate tokenizer
21
- tokenizer = ByteLevelBPETokenizer()
22
-
23
-
24
- def batch_iterator(batch_size=1000):
25
- for i in range(0, len(all_datasets), batch_size):
26
- yield all_datasets[i : i + batch_size]["text"]
27
-
28
-
29
- # Customized training
30
- tokenizer.train_from_iterator(
31
- batch_iterator(),
32
- vocab_size=50265,
33
- min_frequency=2,
34
- special_tokens=[
35
- "<s>",
36
- "<pad>",
37
- "</s>",
38
- "<unk>",
39
- "<mask>",
40
- ],
41
- )
42
-
43
- # Save files to disk
44
- tokenizer.save(f"{model_dir}/tokenizer.json")
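Materialising all of mC4 in Danish, Norwegian and Swedish just to fit a BPE vocabulary is expensive; a streaming variant of the same idea is sketched below. The cap of one million documents is an arbitrary assumption, and `train_from_iterator` only needs an iterable of strings:

    from datasets import load_dataset, interleave_datasets
    from tokenizers import ByteLevelBPETokenizer

    streams = [load_dataset("mc4", lang, split="train", streaming=True) for lang in ("da", "no", "sv")]
    mixed = interleave_datasets(streams)

    def text_iterator(n_docs=1_000_000):
        for i, example in enumerate(mixed):
            if i >= n_docs:
                break
            yield example["text"]

    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train_from_iterator(
        text_iterator(),
        vocab_size=50265,
        min_frequency=2,
        special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
    )
    tokenizer.save("tokenizer.json")
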
 
src/utils.py DELETED
@@ -1,7 +0,0 @@
1
- '''Utility functions and variables used in other scripts'''
2
-
3
- from pathlib import Path
4
-
5
-
6
- root_dir = Path(__file__).parent.parent
7
- model_dir = '' # TODO: Needs to be set
 
tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json DELETED
@@ -1 +0,0 @@
1
- {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "special_tokens_map_file": null, "name_or_path": "./", "tokenizer_class": "RobertaTokenizer"}
 
 
vocab.json DELETED
The diff for this file is too large to render. See raw diff