versae committed
Commit 9c3de9e
1 parent: da38e4c

Saving weights and logs of step 1000

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. config.json +26 -0
  3. eval_results.json +5 -0
  4. events.out.tfevents.1642203685.t1v-n-eedfb410-w-0.10537.0.v2 +3 -0
  5. events.out.tfevents.1642204242.t1v-n-eedfb410-w-0.profile-empty +3 -0
  6. events.out.tfevents.1642608722.t1v-n-eedfb410-w-0.1271442.0.v2 +3 -0
  7. flax_model.msgpack +3 -0
  8. merges.txt +0 -0
  9. run_mlm_flax.py +815 -0
  10. special_tokens_map.json +1 -0
  11. tokenizer.json +0 -0
  12. tokenizer_config.json +1 -0
  13. train.128.sh +25 -0
  14. train.512.sh +26 -0
  15. vocab.json +0 -0
  16. wandb/debug-internal.log +1 -0
  17. wandb/debug.log +1 -0
  18. wandb/latest-run +1 -0
  19. wandb/run-20220114_212855-32qdb4k5/files/code/run_mlm_flax.py +815 -0
  20. wandb/run-20220114_212855-32qdb4k5/files/config.yaml +152 -0
  21. wandb/run-20220114_212855-32qdb4k5/files/diff.patch +0 -0
  22. wandb/run-20220114_212855-32qdb4k5/files/output.log +43 -0
  23. wandb/run-20220114_212855-32qdb4k5/files/requirements.txt +122 -0
  24. wandb/run-20220114_212855-32qdb4k5/files/wandb-metadata.json +47 -0
  25. wandb/run-20220114_212855-32qdb4k5/files/wandb-summary.json +1 -0
  26. wandb/run-20220114_212855-32qdb4k5/logs/debug-internal.log +189 -0
  27. wandb/run-20220114_212855-32qdb4k5/logs/debug.log +150 -0
  28. wandb/run-20220114_212855-32qdb4k5/run-32qdb4k5.wandb +0 -0
  29. wandb/run-20220114_221533-24dma583/files/code/run_mlm_flax.py +815 -0
  30. wandb/run-20220114_221533-24dma583/files/config.yaml +152 -0
  31. wandb/run-20220114_221533-24dma583/files/diff.patch +0 -0
  32. wandb/run-20220114_221533-24dma583/files/output.log +43 -0
  33. wandb/run-20220114_221533-24dma583/files/requirements.txt +122 -0
  34. wandb/run-20220114_221533-24dma583/files/wandb-metadata.json +47 -0
  35. wandb/run-20220114_221533-24dma583/files/wandb-summary.json +1 -0
  36. wandb/run-20220114_221533-24dma583/logs/debug-internal.log +187 -0
  37. wandb/run-20220114_221533-24dma583/logs/debug.log +141 -0
  38. wandb/run-20220114_221533-24dma583/run-24dma583.wandb +0 -0
  39. wandb/run-20220114_234119-1zya86oe/files/code/run_mlm_flax.py +815 -0
  40. wandb/run-20220114_234119-1zya86oe/files/config.yaml +152 -0
  41. wandb/run-20220114_234119-1zya86oe/files/diff.patch +0 -0
  42. wandb/run-20220114_234119-1zya86oe/files/output.log +3 -0
  43. wandb/run-20220114_234119-1zya86oe/files/requirements.txt +122 -0
  44. wandb/run-20220114_234119-1zya86oe/files/wandb-metadata.json +47 -0
  45. wandb/run-20220114_234119-1zya86oe/files/wandb-summary.json +1 -0
  46. wandb/run-20220114_234119-1zya86oe/logs/debug-internal.log +3 -0
  47. wandb/run-20220114_234119-1zya86oe/logs/debug.log +168 -0
  48. wandb/run-20220114_234119-1zya86oe/run-1zya86oe.wandb +3 -0
  49. wandb/run-20220119_161158-274aad95/files/code/run_mlm_flax.py +815 -0
  50. wandb/run-20220119_161158-274aad95/files/config.yaml +147 -0
.gitattributes CHANGED
@@ -25,3 +25,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zstandard filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20220114_234119-1zya86oe/files/output.log filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20220114_234119-1zya86oe/logs/debug-internal.log filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20220114_234119-1zya86oe/run-1zya86oe.wandb filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "_name_or_path": "./",
+   "architectures": [
+     "RobertaForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 514,
+   "model_type": "roberta",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 1,
+   "position_embedding_type": "absolute",
+   "transformers_version": "4.16.0.dev0",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 50265
+ }
eval_results.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "eval_accuracy": 0.6885889057979672,
+   "eval_loss": 1.430497475427537,
+   "eval_perplexity": 4.180778509252052
+ }
events.out.tfevents.1642203685.t1v-n-eedfb410-w-0.10537.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d263549b78d9bb80e115caaffac01f4b141c31e954b64fb7a2ee59c5e4138641
+ size 19208160
events.out.tfevents.1642204242.t1v-n-eedfb410-w-0.profile-empty ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5ac614ecef2709e4ed2bc443ce4ade10122a22097363c5eb86dfadf8e74fa7c5
+ size 40
events.out.tfevents.1642608722.t1v-n-eedfb410-w-0.1271442.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9d194308b26bfc127e2d922f51dc9ca9914119f3b530144d116457498072ad97
+ size 147136
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a12b2311a330a0e74118143c453bf8ad2490d7109372302e7cf9b878f65e181
+ size 498796983
merges.txt ADDED
The diff for this file is too large to render.
 
run_mlm_flax.py ADDED
@@ -0,0 +1,815 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
+ text file or a dataset.
+
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
+ https://huggingface.co/models?filter=fill-mask
+ """
+ import json
+ import logging
+ import math
+ import os
+ import sys
+ import time
+ from dataclasses import asdict, dataclass, field
+ from enum import Enum
+ from itertools import chain
+
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ import numpy as np
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ import flax
+ import jax
+ import jax.numpy as jnp
+ import optax
+ from flax import jax_utils, traverse_util
+ from flax.training import train_state
+ from flax.training.common_utils import get_metrics, onehot, shard
+ from huggingface_hub import Repository
+ from transformers import (
+     CONFIG_MAPPING,
+     FLAX_MODEL_FOR_MASKED_LM_MAPPING,
+     AutoConfig,
+     AutoTokenizer,
+     FlaxAutoModelForMaskedLM,
+     HfArgumentParser,
+     PreTrainedTokenizerBase,
+     TensorType,
+     is_tensorboard_available,
+     set_seed,
+ )
+ from transformers.file_utils import get_full_repo_name
+
+
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+ @dataclass
+ class TrainingArguments:
+     output_dir: str = field(
+         metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
+     )
+     overwrite_output_dir: bool = field(
+         default=False,
+         metadata={
+             "help": (
+                 "Overwrite the content of the output directory. "
+                 "Use this to continue training if output_dir points to a checkpoint directory."
+             )
+         },
+     )
+     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+     do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
+     per_device_train_batch_size: int = field(
+         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+     )
+     per_device_eval_batch_size: int = field(
+         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
+     )
+     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
+     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
+     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
+     adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
+     adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
+     adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
+     num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
+     warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
+     logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
+     save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
+     eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
+     seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+     push_to_hub: bool = field(
+         default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
+     )
+     hub_model_id: str = field(
+         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
+     )
+     hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+
+     def __post_init__(self):
+         if self.output_dir is not None:
+             self.output_dir = os.path.expanduser(self.output_dir)
+
+     def to_dict(self):
+         """
+         Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
+         the token values by removing their value.
+         """
+         d = asdict(self)
+         for k, v in d.items():
+             if isinstance(v, Enum):
+                 d[k] = v.value
+             if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+                 d[k] = [x.value for x in v]
+             if k.endswith("_token"):
+                 d[k] = f"<{k.upper()}>"
+         return d
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+     """
+
+     model_name_or_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The model checkpoint for weights initialization."
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     model_type: Optional[str] = field(
+         default=None,
+         metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+     )
+     dtype: Optional[str] = field(
+         default="float32",
+         metadata={
+             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+         },
+     )
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     dataset_name: Optional[str] = field(
+         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+     validation_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+     )
+     train_ref_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
+     )
+     validation_ref_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+     validation_split_percentage: Optional[int] = field(
+         default=5,
+         metadata={
+             "help": "The percentage of the train set used as validation set in case there's no validation split"
+         },
+     )
+     max_seq_length: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated. Default to the max input length of the model."
+         },
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+     mlm_probability: float = field(
+         default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
+     )
+     pad_to_max_length: bool = field(
+         default=False,
+         metadata={
+             "help": "Whether to pad all samples to `max_seq_length`. "
+             "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+         },
+     )
+     line_by_line: bool = field(
+         default=False,
+         metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
+     )
+
+     def __post_init__(self):
+         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+             raise ValueError("Need either a dataset name or a training/validation file.")
+         else:
+             if self.train_file is not None:
+                 extension = self.train_file.split(".")[-1]
+                 assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+             if self.validation_file is not None:
+                 extension = self.validation_file.split(".")[-1]
+                 assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+
+
+ @flax.struct.dataclass
+ class FlaxDataCollatorForLanguageModeling:
+     """
+     Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
+     are not all of the same length.
+
+     Args:
+         tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+             The tokenizer used for encoding the data.
+         mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
+             The probability with which to (randomly) mask tokens in the input.
+
+     .. note::
+
+         For best performance, this data collator should be used with a dataset having items that are dictionaries or
+         BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
+         :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
+         argument :obj:`return_special_tokens_mask=True`.
+     """
+
+     tokenizer: PreTrainedTokenizerBase
+     mlm_probability: float = 0.15
+
+     def __post_init__(self):
+         if self.tokenizer.mask_token is None:
+             raise ValueError(
+                 "This tokenizer does not have a mask token which is necessary for masked language modeling. "
+                 "You should pass `mlm=False` to train on causal language modeling instead."
+             )
+
+     def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
+         # Handle dict or lists with proper padding and conversion to tensor.
+         batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
+
+         # If special token mask has been preprocessed, pop it from the dict.
+         special_tokens_mask = batch.pop("special_tokens_mask", None)
+
+         batch["input_ids"], batch["labels"] = self.mask_tokens(
+             batch["input_ids"], special_tokens_mask=special_tokens_mask
+         )
+         return batch
+
+     def mask_tokens(
+         self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
+         """
+         labels = inputs.copy()
+         # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
+         probability_matrix = np.full(labels.shape, self.mlm_probability)
+         special_tokens_mask = special_tokens_mask.astype("bool")
+
+         probability_matrix[special_tokens_mask] = 0.0
+         masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
+         labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+         # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+         indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
+         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
+
+         # 10% of the time, we replace masked input tokens with random word
+         indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
+         indices_random &= masked_indices & ~indices_replaced
+
+         random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
+         inputs[indices_random] = random_words[indices_random]
+
+         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+         return inputs, labels
+
+
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+     num_samples = len(samples_idx)
+     samples_to_remove = num_samples % batch_size
+
+     if samples_to_remove != 0:
+         samples_idx = samples_idx[:-samples_to_remove]
+     sections_split = num_samples // batch_size
+     batch_idx = np.split(samples_idx, sections_split)
+     return batch_idx
+
+
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
+     summary_writer.scalar("train_time", train_time, step)
+
+     train_metrics = get_metrics(train_metrics)
+     for key, vals in train_metrics.items():
+         tag = f"train_{key}"
+         for i, val in enumerate(vals):
+             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+
+ def write_eval_metric(summary_writer, eval_metrics, step):
+     for metric_name, value in eval_metrics.items():
+         summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     if (
+         os.path.exists(training_args.output_dir)
+         and os.listdir(training_args.output_dir)
+         and training_args.do_train
+         and not training_args.overwrite_output_dir
+     ):
+         raise ValueError(
+             f"Output directory ({training_args.output_dir}) already exists and is not empty."
+             "Use --overwrite_output_dir to overcome."
+         )
+
+     # Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         level=logging.INFO,
+         datefmt="[%X]",
+     )
+
+     # Log on each process the small summary:
+     logger = logging.getLogger(__name__)
+
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     # Handle the repository creation
+     if training_args.push_to_hub:
+         if training_args.hub_model_id is None:
+             repo_name = get_full_repo_name(
+                 Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+             )
+         else:
+             repo_name = training_args.hub_model_id
+         repo = Repository(training_args.output_dir, clone_from=repo_name)
+
+     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+     # (the dataset will be downloaded automatically from the datasets Hub).
+     #
+     # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+     # 'text' is found. You can easily tweak this behavior (see below).
+     #
+     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+     # download the dataset.
+     if data_args.dataset_name is not None:
+         # Downloading and loading a dataset from the hub.
+         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
+
+         if "validation" not in datasets.keys():
+             datasets["validation"] = load_dataset(
+                 data_args.dataset_name,
+                 data_args.dataset_config_name,
+                 split=f"train[:{data_args.validation_split_percentage}%]",
+                 cache_dir=model_args.cache_dir,
+             )
+             datasets["train"] = load_dataset(
+                 data_args.dataset_name,
+                 data_args.dataset_config_name,
+                 split=f"train[{data_args.validation_split_percentage}%:]",
+                 cache_dir=model_args.cache_dir,
+             )
+     else:
+         data_files = {}
+         if data_args.train_file is not None:
+             data_files["train"] = data_args.train_file
+         if data_args.validation_file is not None:
+             data_files["validation"] = data_args.validation_file
+         extension = data_args.train_file.split(".")[-1]
+         if extension == "txt":
+             extension = "text"
+         datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+
+         if "validation" not in datasets.keys():
+             datasets["validation"] = load_dataset(
+                 extension,
+                 data_files=data_files,
+                 split=f"train[:{data_args.validation_split_percentage}%]",
+                 cache_dir=model_args.cache_dir,
+             )
+             datasets["train"] = load_dataset(
+                 extension,
+                 data_files=data_files,
+                 split=f"train[{data_args.validation_split_percentage}%:]",
+                 cache_dir=model_args.cache_dir,
+             )
+     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+     # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+     # Load pretrained model and tokenizer
+
+     # Distributed training:
+     # The .from_pretrained methods guarantee that only one local process can concurrently
+     # download model & vocab.
+     if model_args.config_name:
+         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+     elif model_args.model_name_or_path:
+         config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+     else:
+         config = CONFIG_MAPPING[model_args.model_type]()
+         logger.warning("You are instantiating a new config instance from scratch.")
+
+     if model_args.tokenizer_name:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     elif model_args.model_name_or_path:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     else:
+         raise ValueError(
+             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+         )
+
+     # Preprocessing the datasets.
+     # First we tokenize all the texts.
+     if training_args.do_train:
+         column_names = datasets["train"].column_names
+     else:
+         column_names = datasets["validation"].column_names
+     text_column_name = "text" if "text" in column_names else column_names[0]
+
+     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+     if data_args.line_by_line:
+         # When using line_by_line, we just tokenize each nonempty line.
+         padding = "max_length" if data_args.pad_to_max_length else False
+
+         def tokenize_function(examples):
+             # Remove empty lines
+             examples = [line for line in examples if len(line) > 0 and not line.isspace()]
+             return tokenizer(
+                 examples,
+                 return_special_tokens_mask=True,
+                 padding=padding,
+                 truncation=True,
+                 max_length=max_seq_length,
+             )
+
+         tokenized_datasets = datasets.map(
+             tokenize_function,
+             input_columns=[text_column_name],
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+         )
+
+     else:
+         # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
+         # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
+         # efficient when it receives the `special_tokens_mask`.
+         def tokenize_function(examples):
+             return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
+
+         tokenized_datasets = datasets.map(
+             tokenize_function,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+         )
+
+         # Main data processing function that will concatenate all texts from our dataset and generate chunks of
+         # max_seq_length.
+         def group_texts(examples):
+             # Concatenate all texts.
+             concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+             total_length = len(concatenated_examples[list(examples.keys())[0]])
+             # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+             # customize this part to your needs.
+             if total_length >= max_seq_length:
+                 total_length = (total_length // max_seq_length) * max_seq_length
+             # Split by chunks of max_len.
+             result = {
+                 k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
+                 for k, t in concatenated_examples.items()
+             }
+             return result
+
+         # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
+         # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
+         # might be slower to preprocess.
+         #
+         # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+         # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+         tokenized_datasets = tokenized_datasets.map(
+             group_texts,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             load_from_cache_file=not data_args.overwrite_cache,
+         )
+
+     # Enable tensorboard only on the master node
+     has_tensorboard = is_tensorboard_available()
+     if has_tensorboard and jax.process_index() == 0:
+         try:
+             # Enable Weight&Biases
+             import wandb
+             wandb.init(
+                 entity='versae',
+                 project='roberta-base-ncc',
+                 sync_tensorboard=False,
+             )
+             wandb.config.update(training_args)
+             wandb.config.update(model_args)
+             wandb.config.update(data_args)
+
+             from flax.metrics.tensorboard import SummaryWriter
+
+             summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+         except ImportError as ie:
+             has_tensorboard = False
+             logger.warning(
+                 f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+             )
+     else:
+         logger.warning(
+             "Unable to display metrics through TensorBoard because the package is not installed: "
+             "Please run pip install tensorboard to enable."
+         )
+
+     # Data collator
+     # This one will take care of randomly masking the tokens.
+     data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
+
+     # Initialize our training
+     rng = jax.random.PRNGKey(training_args.seed)
+     dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+     if model_args.model_name_or_path:
+         model = FlaxAutoModelForMaskedLM.from_pretrained(
+             model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+         )
+     else:
+         model = FlaxAutoModelForMaskedLM.from_config(
+             config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+         )
+
+     # Store some constant
+     num_epochs = int(training_args.num_train_epochs)
+     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+
+     num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
+
+     # Create learning rate schedule
+     warmup_fn = optax.linear_schedule(
+         init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
+     )
+     decay_fn = optax.linear_schedule(
+         init_value=training_args.learning_rate,
+         end_value=0,
+         transition_steps=num_train_steps - training_args.warmup_steps,
+     )
+     linear_decay_lr_schedule_fn = optax.join_schedules(
+         schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+     )
+
+     # We use Optax's "masking" functionality to not apply weight decay
+     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+     # mask boolean with the same structure as the parameters.
+     # The mask is True for parameters that should be decayed.
+     # Note that this mask is specifically adapted for FlaxBERT-like models.
+     # For other models, one should correct the layer norm parameter naming
+     # accordingly.
+     def decay_mask_fn(params):
+         flat_params = traverse_util.flatten_dict(params)
+         flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+         return traverse_util.unflatten_dict(flat_mask)
+
+     # create adam optimizer
+     if training_args.adafactor:
+         # We use the default parameters here to initialize adafactor,
+         # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+         optimizer = optax.adafactor(
+             learning_rate=linear_decay_lr_schedule_fn,
+         )
+     else:
+         optimizer = optax.adamw(
+             learning_rate=linear_decay_lr_schedule_fn,
+             b1=training_args.adam_beta1,
+             b2=training_args.adam_beta2,
+             eps=training_args.adam_epsilon,
+             weight_decay=training_args.weight_decay,
+             mask=decay_mask_fn,
+         )
+
+     # Setup train state
+     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
+
+     # Define gradient update step fn
+     def train_step(state, batch, dropout_rng):
+         dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+
+         def loss_fn(params):
+             labels = batch.pop("labels")
+
+             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+
+             # compute loss, ignore padded input tokens
+             label_mask = jnp.where(labels > 0, 1.0, 0.0)
+             loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
+
+             # take average
+             loss = loss.sum() / label_mask.sum()
+
+             return loss
+
+         grad_fn = jax.value_and_grad(loss_fn)
+         loss, grad = grad_fn(state.params)
+         grad = jax.lax.pmean(grad, "batch")
+         new_state = state.apply_gradients(grads=grad)
+
+         metrics = jax.lax.pmean(
+             {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
+         )
+
+         return new_state, metrics, new_dropout_rng
+
+     # Create parallel version of the train step
+     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+     # Define eval fn
+     def eval_step(params, batch):
+         labels = batch.pop("labels")
+
+         logits = model(**batch, params=params, train=False)[0]
+
+         # compute loss, ignore padded input tokens
+         label_mask = jnp.where(labels > 0, 1.0, 0.0)
+         loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
+
+         # compute accuracy
+         accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
+
+         # summarize metrics
+         metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
+         metrics = jax.lax.psum(metrics, axis_name="batch")
+
+         return metrics
+
+     p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
+
+     # Replicate the train state on each device
+     state = jax_utils.replicate(state)
+
+     train_time = 0
+     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+     for epoch in epochs:
+         # ======================== Training ================================
+         train_start = time.time()
+         train_metrics = []
+
+         # Create sampling rng
+         rng, input_rng = jax.random.split(rng)
+
+         # Generate an epoch by shuffling sampling indices from the train dataset
+         num_train_samples = len(tokenized_datasets["train"])
+         train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
+         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
+
+         # Gather the indexes for creating the batch and do a training step
+         for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
+             model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+             # Model forward
+             model_inputs = shard(model_inputs.data)
+             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
+             train_metrics.append(train_metric)
+
+             cur_step = epoch * (num_train_samples // train_batch_size) + step
+
+             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                 # Save metrics
+                 train_metric = jax_utils.unreplicate(train_metric)
+                 train_time += time.time() - train_start
+                 if has_tensorboard and jax.process_index() == 0:
+                     write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                 epochs.write(
+                     f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                 )
+
+                 train_metrics = []
+
+             if cur_step % training_args.eval_steps == 0 and cur_step > 0:
+                 # ======================== Evaluating ==============================
+                 num_eval_samples = len(tokenized_datasets["validation"])
+                 eval_samples_idx = jnp.arange(num_eval_samples)
+                 eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+                 eval_metrics = []
+                 for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+                     samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+                     model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+                     # Model forward
+                     model_inputs = shard(model_inputs.data)
+                     metrics = p_eval_step(state.params, model_inputs)
+                     eval_metrics.append(metrics)
+
+                 # normalize eval metrics
+                 eval_metrics = get_metrics(eval_metrics)
+                 eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+                 eval_normalizer = eval_metrics.pop("normalizer")
+                 eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+                 # Update progress bar
+                 epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+
+                 # Save metrics
+                 if has_tensorboard and jax.process_index() == 0:
+                     write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+             if cur_step % training_args.save_steps == 0 and cur_step > 0:
+                 # save checkpoint after each epoch and push checkpoint to the hub
+                 if jax.process_index() == 0:
+                     params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                     model.save_pretrained(training_args.output_dir, params=params)
+                     tokenizer.save_pretrained(training_args.output_dir)
+                     if training_args.push_to_hub:
+                         repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
+
+     # Eval after training
+     if training_args.do_eval:
+         num_eval_samples = len(tokenized_datasets["validation"])
+         eval_samples_idx = jnp.arange(num_eval_samples)
+         eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+         eval_metrics = []
+         for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+             samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+             model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+             # Model forward
+             model_inputs = shard(model_inputs.data)
+             metrics = p_eval_step(state.params, model_inputs)
+             eval_metrics.append(metrics)
+
+         # normalize eval metrics
+         eval_metrics = get_metrics(eval_metrics)
+         eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
+         eval_normalizer = eval_metrics.pop("normalizer")
+         eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+         try:
+             perplexity = math.exp(eval_metrics["loss"])
+         except OverflowError:
+             perplexity = float("inf")
+         eval_metrics["perplexity"] = perplexity
+
+         if jax.process_index() == 0:
+             eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
+             path = os.path.join(training_args.output_dir, "eval_results.json")
+             with open(path, "w") as f:
+                 json.dump(eval_metrics, f, indent=4, sort_keys=True)
+
+
+ if __name__ == "__main__":
+     main()
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render.
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "trim_offsets": true, "special_tokens_map_file": null, "name_or_path": "./", "tokenizer_class": "RobertaTokenizer"}
train.128.sh ADDED
@@ -0,0 +1,25 @@
+ python run_mlm_flax.py \
+     --output_dir="./" \
+     --model_type="roberta" \
+     --config_name="roberta-base" \
+     --tokenizer_name="NbAiLab/nb-roberta-base" \
+     --dataset_name="NbAiLab/NCC" \
+     --max_seq_length="128" \
+     --weight_decay="0.01" \
+     --per_device_train_batch_size="232" \
+     --per_device_eval_batch_size="232" \
+     --pad_to_max_length \
+     --learning_rate="6e-4" \
+     --warmup_steps="10000" \
+     --overwrite_output_dir \
+     --num_train_epochs="3" \
+     --adam_beta1="0.9" \
+     --adam_beta2="0.98" \
+     --adam_epsilon="1e-6" \
+     --logging_steps="1000" \
+     --save_steps="1000" \
+     --eval_steps="1000" \
+     --do_train \
+     --do_eval \
+     --dtype="bfloat16" \
+     --push_to_hub
train.512.sh ADDED
@@ -0,0 +1,26 @@
+ python run_mlm_flax.py \
+     --output_dir="./" \
+     --model_type="roberta" \
+     --model_name_or_path="./" \
+     --config_name="./" \
+     --tokenizer_name="./" \
+     --dataset_name="NbAiLab/NCC" \
+     --max_seq_length="512" \
+     --weight_decay="0.01" \
+     --per_device_train_batch_size="46" \
+     --per_device_eval_batch_size="46" \
+     --pad_to_max_length \
+     --learning_rate="6e-4" \
+     --warmup_steps="1000" \
+     --overwrite_output_dir \
+     --num_train_epochs="3" \
+     --adam_beta1="0.9" \
+     --adam_beta2="0.98" \
+     --adam_epsilon="1e-6" \
+     --logging_steps="1000" \
+     --save_steps="1000" \
+     --eval_steps="1000" \
+     --do_train \
+     --do_eval \
+     --dtype="bfloat16" \
+     --push_to_hub
vocab.json ADDED
The diff for this file is too large to render.
 
wandb/debug-internal.log ADDED
@@ -0,0 +1 @@
+ run-20220119_161158-274aad95/logs/debug-internal.log
wandb/debug.log ADDED
@@ -0,0 +1 @@
+ run-20220119_161158-274aad95/logs/debug.log
wandb/latest-run ADDED
@@ -0,0 +1 @@
+ run-20220119_161158-274aad95
wandb/run-20220114_212855-32qdb4k5/files/code/run_mlm_flax.py ADDED
@@ -0,0 +1,815 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=fill-mask
22
+ """
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import asdict, dataclass, field
30
+ from enum import Enum
31
+ from itertools import chain
32
+
33
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
34
+ from pathlib import Path
35
+ from typing import Dict, List, Optional, Tuple
36
+
37
+ import numpy as np
38
+ from datasets import load_dataset
39
+ from tqdm import tqdm
40
+
41
+ import flax
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import optax
45
+ from flax import jax_utils, traverse_util
46
+ from flax.training import train_state
47
+ from flax.training.common_utils import get_metrics, onehot, shard
48
+ from huggingface_hub import Repository
49
+ from transformers import (
50
+ CONFIG_MAPPING,
51
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
52
+ AutoConfig,
53
+ AutoTokenizer,
54
+ FlaxAutoModelForMaskedLM,
55
+ HfArgumentParser,
56
+ PreTrainedTokenizerBase,
57
+ TensorType,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ from transformers.file_utils import get_full_repo_name
62
+
63
+
64
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
+
67
+
68
+ @dataclass
69
+ class TrainingArguments:
70
+ output_dir: str = field(
71
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
72
+ )
73
+ overwrite_output_dir: bool = field(
74
+ default=False,
75
+ metadata={
76
+ "help": (
77
+ "Overwrite the content of the output directory. "
78
+ "Use this to continue training if output_dir points to a checkpoint directory."
79
+ )
80
+ },
81
+ )
82
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
83
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
84
+ per_device_train_batch_size: int = field(
85
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
86
+ )
87
+ per_device_eval_batch_size: int = field(
88
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
89
+ )
90
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
91
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
92
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
93
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
94
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
95
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
96
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
97
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
98
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
99
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
100
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
101
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
102
+ push_to_hub: bool = field(
103
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
104
+ )
105
+ hub_model_id: str = field(
106
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
107
+ )
108
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
109
+
110
+ def __post_init__(self):
111
+ if self.output_dir is not None:
112
+ self.output_dir = os.path.expanduser(self.output_dir)
113
+
114
+ def to_dict(self):
115
+ """
116
+ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
117
+ the token values by removing their value.
118
+ """
119
+ d = asdict(self)
120
+ for k, v in d.items():
121
+ if isinstance(v, Enum):
122
+ d[k] = v.value
123
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
124
+ d[k] = [x.value for x in v]
125
+ if k.endswith("_token"):
126
+ d[k] = f"<{k.upper()}>"
127
+ return d
128
+
129
+
130
+ @dataclass
131
+ class ModelArguments:
132
+ """
133
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
134
+ """
135
+
136
+ model_name_or_path: Optional[str] = field(
137
+ default=None,
138
+ metadata={
139
+ "help": "The model checkpoint for weights initialization."
140
+ "Don't set if you want to train a model from scratch."
141
+ },
142
+ )
143
+ model_type: Optional[str] = field(
144
+ default=None,
145
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
146
+ )
147
+ config_name: Optional[str] = field(
148
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
149
+ )
150
+ tokenizer_name: Optional[str] = field(
151
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
152
+ )
153
+ cache_dir: Optional[str] = field(
154
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
155
+ )
156
+ use_fast_tokenizer: bool = field(
157
+ default=True,
158
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
159
+ )
160
+ dtype: Optional[str] = field(
161
+ default="float32",
162
+ metadata={
163
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
164
+ },
165
+ )
166
+
167
+
168
+ @dataclass
169
+ class DataTrainingArguments:
170
+ """
171
+ Arguments pertaining to what data we are going to input our model for training and eval.
172
+ """
173
+
174
+ dataset_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
176
+ )
177
+ dataset_config_name: Optional[str] = field(
178
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
179
+ )
180
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
181
+ validation_file: Optional[str] = field(
182
+ default=None,
183
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
184
+ )
185
+ train_ref_file: Optional[str] = field(
186
+ default=None,
187
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
188
+ )
189
+ validation_ref_file: Optional[str] = field(
190
+ default=None,
191
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
192
+ )
193
+ overwrite_cache: bool = field(
194
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
195
+ )
196
+ validation_split_percentage: Optional[int] = field(
197
+ default=5,
198
+ metadata={
199
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
200
+ },
201
+ )
202
+ max_seq_length: Optional[int] = field(
203
+ default=None,
204
+ metadata={
205
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
206
+ "than this will be truncated. Default to the max input length of the model."
207
+ },
208
+ )
209
+ preprocessing_num_workers: Optional[int] = field(
210
+ default=None,
211
+ metadata={"help": "The number of processes to use for the preprocessing."},
212
+ )
213
+ mlm_probability: float = field(
214
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
215
+ )
216
+ pad_to_max_length: bool = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "Whether to pad all samples to `max_seq_length`. "
220
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
221
+ },
222
+ )
223
+ line_by_line: bool = field(
224
+ default=False,
225
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
226
+ )
227
+
228
+ def __post_init__(self):
229
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
230
+ raise ValueError("Need either a dataset name or a training/validation file.")
231
+ else:
232
+ if self.train_file is not None:
233
+ extension = self.train_file.split(".")[-1]
234
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
235
+ if self.validation_file is not None:
236
+ extension = self.validation_file.split(".")[-1]
237
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
238
+
239
+
240
+ @flax.struct.dataclass
241
+ class FlaxDataCollatorForLanguageModeling:
242
+ """
243
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
244
+ are not all of the same length.
245
+
246
+ Args:
247
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
248
+ The tokenizer used for encoding the data.
249
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
250
+ The probability with which to (randomly) mask tokens in the input.
251
+
252
+ .. note::
253
+
254
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
255
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
256
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
257
+ argument :obj:`return_special_tokens_mask=True`.
258
+ """
259
+
260
+ tokenizer: PreTrainedTokenizerBase
261
+ mlm_probability: float = 0.15
262
+
263
+ def __post_init__(self):
264
+ if self.tokenizer.mask_token is None:
265
+ raise ValueError(
266
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
267
+ "You should pass `mlm=False` to train on causal language modeling instead."
268
+ )
269
+
270
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
271
+ # Handle dict or lists with proper padding and conversion to tensor.
272
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
273
+
274
+ # If special token mask has been preprocessed, pop it from the dict.
275
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
276
+
277
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
278
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
279
+ )
280
+ return batch
281
+
282
+ def mask_tokens(
283
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
284
+ ) -> Tuple[np.ndarray, np.ndarray]:
285
+ """
286
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
287
+ """
288
+ labels = inputs.copy()
289
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
290
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
291
+ special_tokens_mask = special_tokens_mask.astype("bool")
292
+
293
+ probability_matrix[special_tokens_mask] = 0.0
294
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
295
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
296
+
297
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
298
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
299
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
300
+
301
+ # 10% of the time, we replace masked input tokens with random word
302
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
303
+ indices_random &= masked_indices & ~indices_replaced
304
+
305
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
306
+ inputs[indices_random] = random_words[indices_random]
307
+
308
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
309
+ return inputs, labels
310
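The masking above follows BERT's 80/10/10 recipe. A minimal NumPy-only sketch of the same logic on a made-up batch (the token ids, mask id 4, and vocab size 50 are all hypothetical):

import numpy as np

np.random.seed(0)
inputs = np.array([[5, 17, 42, 8, 3]])             # toy token ids
special_tokens_mask = np.array([[1, 0, 0, 0, 1]])  # first/last positions are special tokens
labels = inputs.copy()

probability_matrix = np.full(labels.shape, 0.15)
probability_matrix[special_tokens_mask.astype(bool)] = 0.0
masked_indices = np.random.binomial(1, probability_matrix).astype(bool)
labels[~masked_indices] = -100  # loss is computed only on masked positions

# 80% of masked positions -> mask id, 10% -> random id, 10% -> left unchanged
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype(bool) & masked_indices
inputs[indices_replaced] = 4
indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype(bool) & masked_indices & ~indices_replaced
inputs[indices_random] = np.random.randint(50, size=labels.shape)[indices_random]
print(inputs, labels)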
+
311
+
312
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
313
+ num_samples = len(samples_idx)
314
+ samples_to_remove = num_samples % batch_size
315
+
316
+ if samples_to_remove != 0:
317
+ samples_idx = samples_idx[:-samples_to_remove]
318
+ sections_split = num_samples // batch_size
319
+ batch_idx = np.split(samples_idx, sections_split)
320
+ return batch_idx
321
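For example, generate_batch_splits with 10 shuffled sample indices and a batch size of 4 drops the 2 leftover samples and yields two full batches; a small sketch with made-up numbers:

import numpy as np

samples_idx = np.arange(10)
batch_size = 4
samples_to_remove = len(samples_idx) % batch_size  # 2
trimmed = samples_idx[:-samples_to_remove]         # keep the first 8 indices
batch_idx = np.split(trimmed, len(samples_idx) // batch_size)
print(batch_idx)  # [array([0, 1, 2, 3]), array([4, 5, 6, 7])]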
+
322
+
323
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
324
+ summary_writer.scalar("train_time", train_time, step)
325
+
326
+ train_metrics = get_metrics(train_metrics)
327
+ for key, vals in train_metrics.items():
328
+ tag = f"train_{key}"
329
+ for i, val in enumerate(vals):
330
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
331
+
332
+
333
+ def write_eval_metric(summary_writer, eval_metrics, step):
334
+ for metric_name, value in eval_metrics.items():
335
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
336
+
337
+
338
+ def main():
339
+ # See all possible arguments in src/transformers/training_args.py
340
+ # or by passing the --help flag to this script.
341
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
342
+
343
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
344
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
345
+ # If we pass only one argument to the script and it's the path to a json file,
346
+ # let's parse it to get our arguments.
347
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
348
+ else:
349
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
350
+
351
+ if (
352
+ os.path.exists(training_args.output_dir)
353
+ and os.listdir(training_args.output_dir)
354
+ and training_args.do_train
355
+ and not training_args.overwrite_output_dir
356
+ ):
357
+ raise ValueError(
358
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
359
+ "Use --overwrite_output_dir to overcome."
360
+ )
361
+
362
+ # Setup logging
363
+ logging.basicConfig(
364
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
365
+ level=logging.INFO,
366
+ datefmt="[%X]",
367
+ )
368
+
369
+ # Log on each process the small summary:
370
+ logger = logging.getLogger(__name__)
371
+
372
+ # Set the verbosity to info of the Transformers logger (on main process only):
373
+ logger.info(f"Training/evaluation parameters {training_args}")
374
+
375
+ # Set seed before initializing model.
376
+ set_seed(training_args.seed)
377
+
378
+ # Handle the repository creation
379
+ if training_args.push_to_hub:
380
+ if training_args.hub_model_id is None:
381
+ repo_name = get_full_repo_name(
382
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
383
+ )
384
+ else:
385
+ repo_name = training_args.hub_model_id
386
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
387
+
388
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
389
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
390
+ # (the dataset will be downloaded automatically from the datasets Hub).
391
+ #
392
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
393
+ # 'text' is found. You can easily tweak this behavior (see below).
394
+ #
395
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
396
+ # download the dataset.
397
+ if data_args.dataset_name is not None:
398
+ # Downloading and loading a dataset from the hub.
399
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
400
+
401
+ if "validation" not in datasets.keys():
402
+ datasets["validation"] = load_dataset(
403
+ data_args.dataset_name,
404
+ data_args.dataset_config_name,
405
+ split=f"train[:{data_args.validation_split_percentage}%]",
406
+ cache_dir=model_args.cache_dir,
407
+ )
408
+ datasets["train"] = load_dataset(
409
+ data_args.dataset_name,
410
+ data_args.dataset_config_name,
411
+ split=f"train[{data_args.validation_split_percentage}%:]",
412
+ cache_dir=model_args.cache_dir,
413
+ )
414
+ else:
415
+ data_files = {}
416
+ if data_args.train_file is not None:
417
+ data_files["train"] = data_args.train_file
418
+ if data_args.validation_file is not None:
419
+ data_files["validation"] = data_args.validation_file
420
+ extension = data_args.train_file.split(".")[-1]
421
+ if extension == "txt":
422
+ extension = "text"
423
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
424
+
425
+ if "validation" not in datasets.keys():
426
+ datasets["validation"] = load_dataset(
427
+ extension,
428
+ data_files=data_files,
429
+ split=f"train[:{data_args.validation_split_percentage}%]",
430
+ cache_dir=model_args.cache_dir,
431
+ )
432
+ datasets["train"] = load_dataset(
433
+ extension,
434
+ data_files=data_files,
435
+ split=f"train[{data_args.validation_split_percentage}%:]",
436
+ cache_dir=model_args.cache_dir,
437
+ )
438
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
439
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
440
+
441
+ # Load pretrained model and tokenizer
442
+
443
+ # Distributed training:
444
+ # The .from_pretrained methods guarantee that only one local process can concurrently
445
+ # download model & vocab.
446
+ if model_args.config_name:
447
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
448
+ elif model_args.model_name_or_path:
449
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
450
+ else:
451
+ config = CONFIG_MAPPING[model_args.model_type]()
452
+ logger.warning("You are instantiating a new config instance from scratch.")
453
+
454
+ if model_args.tokenizer_name:
455
+ tokenizer = AutoTokenizer.from_pretrained(
456
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
457
+ )
458
+ elif model_args.model_name_or_path:
459
+ tokenizer = AutoTokenizer.from_pretrained(
460
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
461
+ )
462
+ else:
463
+ raise ValueError(
464
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
465
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
466
+ )
467
+
468
+ # Preprocessing the datasets.
469
+ # First we tokenize all the texts.
470
+ if training_args.do_train:
471
+ column_names = datasets["train"].column_names
472
+ else:
473
+ column_names = datasets["validation"].column_names
474
+ text_column_name = "text" if "text" in column_names else column_names[0]
475
+
476
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
477
+
478
+ if data_args.line_by_line:
479
+ # When using line_by_line, we just tokenize each nonempty line.
480
+ padding = "max_length" if data_args.pad_to_max_length else False
481
+
482
+ def tokenize_function(examples):
483
+ # Remove empty lines
484
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
485
+ return tokenizer(
486
+ examples,
487
+ return_special_tokens_mask=True,
488
+ padding=padding,
489
+ truncation=True,
490
+ max_length=max_seq_length,
491
+ )
492
+
493
+ tokenized_datasets = datasets.map(
494
+ tokenize_function,
495
+ input_columns=[text_column_name],
496
+ batched=True,
497
+ num_proc=data_args.preprocessing_num_workers,
498
+ remove_columns=column_names,
499
+ load_from_cache_file=not data_args.overwrite_cache,
500
+ )
501
+
502
+ else:
503
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them into smaller parts.
504
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
505
+ # efficient when it receives the `special_tokens_mask`.
506
+ def tokenize_function(examples):
507
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
508
+
509
+ tokenized_datasets = datasets.map(
510
+ tokenize_function,
511
+ batched=True,
512
+ num_proc=data_args.preprocessing_num_workers,
513
+ remove_columns=column_names,
514
+ load_from_cache_file=not data_args.overwrite_cache,
515
+ )
516
+
517
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
518
+ # max_seq_length.
519
+ def group_texts(examples):
520
+ # Concatenate all texts.
521
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
522
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
523
+ # We drop the small remainder; we could add padding instead of dropping if the model supported it. You can
524
+ # customize this part to your needs.
525
+ if total_length >= max_seq_length:
526
+ total_length = (total_length // max_seq_length) * max_seq_length
527
+ # Split by chunks of max_len.
528
+ result = {
529
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
530
+ for k, t in concatenated_examples.items()
531
+ }
532
+ return result
533
+
534
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
535
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
536
+ # might be slower to preprocess.
537
+ #
538
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
539
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
540
+ tokenized_datasets = tokenized_datasets.map(
541
+ group_texts,
542
+ batched=True,
543
+ num_proc=data_args.preprocessing_num_workers,
544
+ load_from_cache_file=not data_args.overwrite_cache,
545
+ )
546
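A minimal sketch of what group_texts does to one batch of already-tokenized examples, using toy ids and max_seq_length=4 for illustration:

from itertools import chain

max_seq_length = 4
examples = {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]]}

concatenated = {k: list(chain(*v)) for k, v in examples.items()}
total_length = (len(concatenated["input_ids"]) // max_seq_length) * max_seq_length
chunks = {
    k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
    for k, t in concatenated.items()
}
print(chunks)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]} -- the [9, 10] remainder is dropped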
+
547
+ # Enable tensorboard only on the master node
548
+ has_tensorboard = is_tensorboard_available()
549
+ if has_tensorboard and jax.process_index() == 0:
550
+ try:
551
+ # Enable Weight&Biases
552
+ import wandb
553
+ wandb.init(
554
+ entity='versae',
555
+ project='roberta-base-ncc',
556
+ sync_tensorboard=False,
557
+ )
558
+ wandb.config.update(training_args)
559
+ wandb.config.update(model_args)
560
+ wandb.config.update(data_args)
561
+
562
+ from flax.metrics.tensorboard import SummaryWriter
563
+
564
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
565
+ except ImportError as ie:
566
+ has_tensorboard = False
567
+ logger.warning(
568
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
569
+ )
570
+ else:
571
+ logger.warning(
572
+ "Unable to display metrics through TensorBoard because the package is not installed: "
573
+ "Please run pip install tensorboard to enable."
574
+ )
575
+
576
+ # Data collator
577
+ # This one will take care of randomly masking the tokens.
578
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
579
+
580
+ # Initialize our training
581
+ rng = jax.random.PRNGKey(training_args.seed)
582
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
583
+
584
+ if model_args.model_name_or_path:
585
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
586
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
587
+ )
588
+ else:
589
+ model = FlaxAutoModelForMaskedLM.from_config(
590
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
591
+ )
592
+
593
+ # Store some constants
594
+ num_epochs = int(training_args.num_train_epochs)
595
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
596
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
597
+
598
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
599
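To make the arithmetic above concrete, a sketch with the per-device batch size from this run and an assumed 8-device TPU host (the device count and dataset size below are assumptions, not values read from the logs):

per_device_train_batch_size = 250
device_count = 8                 # assumed; jax.device_count() in practice
num_train_samples = 5_000_000    # made-up dataset size
num_train_epochs = 3

train_batch_size = per_device_train_batch_size * device_count               # 2000
num_train_steps = num_train_samples // train_batch_size * num_train_epochs  # 7500
print(train_batch_size, num_train_steps)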
+
600
+ # Create learning rate schedule
601
+ warmup_fn = optax.linear_schedule(
602
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
603
+ )
604
+ decay_fn = optax.linear_schedule(
605
+ init_value=training_args.learning_rate,
606
+ end_value=0,
607
+ transition_steps=num_train_steps - training_args.warmup_steps,
608
+ )
609
+ linear_decay_lr_schedule_fn = optax.join_schedules(
610
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
611
+ )
612
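A small optax sketch of the joined warmup/decay schedule, evaluated at a few steps (learning_rate and warmup_steps match this run; num_train_steps here is made up):

import optax

learning_rate, warmup_steps, num_train_steps = 6e-4, 10_000, 100_000
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(init_value=learning_rate, end_value=0.0,
                                 transition_steps=num_train_steps - warmup_steps)
schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

for step in (0, 5_000, 10_000, 55_000, 100_000):
    print(step, float(schedule(step)))  # ramps up to 6e-4, then decays linearly back to 0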
+
613
+ # We use Optax's "masking" functionality to not apply weight decay
614
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
615
+ # boolean mask with the same structure as the parameters.
616
+ # The mask is True for parameters that should be decayed.
617
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
618
+ # For other models, one should correct the layer norm parameter naming
619
+ # accordingly.
620
+ def decay_mask_fn(params):
621
+ flat_params = traverse_util.flatten_dict(params)
622
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
623
+ return traverse_util.unflatten_dict(flat_mask)
624
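A toy parameter tree passed through a decay mask of this shape (the module names are made up; real Flax models nest differently):

from flax import traverse_util

params = {
    "encoder": {
        "dense": {"kernel": 0.0, "bias": 0.0},
        "LayerNorm": {"scale": 0.0, "bias": 0.0},
    }
}
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
print(traverse_util.unflatten_dict(flat_mask))
# only encoder/dense/kernel maps to True; biases and the LayerNorm scale are excluded from weight decay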
+
625
+ # Create the optimizer (AdamW by default, or Adafactor if requested)
626
+ if training_args.adafactor:
627
+ # We use the default parameters here to initialize Adafactor.
628
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
629
+ optimizer = optax.adafactor(
630
+ learning_rate=linear_decay_lr_schedule_fn,
631
+ )
632
+ else:
633
+ optimizer = optax.adamw(
634
+ learning_rate=linear_decay_lr_schedule_fn,
635
+ b1=training_args.adam_beta1,
636
+ b2=training_args.adam_beta2,
637
+ eps=training_args.adam_epsilon,
638
+ weight_decay=training_args.weight_decay,
639
+ mask=decay_mask_fn,
640
+ )
641
+
642
+ # Setup train state
643
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
644
+
645
+ # Define gradient update step fn
646
+ def train_step(state, batch, dropout_rng):
647
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
648
+
649
+ def loss_fn(params):
650
+ labels = batch.pop("labels")
651
+
652
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
653
+
654
+ # compute loss, ignore padded input tokens
655
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
656
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
657
+
658
+ # take average
659
+ loss = loss.sum() / label_mask.sum()
660
+
661
+ return loss
662
+
663
+ grad_fn = jax.value_and_grad(loss_fn)
664
+ loss, grad = grad_fn(state.params)
665
+ grad = jax.lax.pmean(grad, "batch")
666
+ new_state = state.apply_gradients(grads=grad)
667
+
668
+ metrics = jax.lax.pmean(
669
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
670
+ )
671
+
672
+ return new_state, metrics, new_dropout_rng
673
+
674
+ # Create parallel version of the train step
675
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
676
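Before each pmapped step the collated batch is sharded across local devices; a shapes-only sketch (the 2000 = 250 x 8 split assumes an 8-device host):

import numpy as np
from flax.training.common_utils import shard

batch = {"input_ids": np.zeros((2000, 128), dtype="i4"),
         "labels": np.zeros((2000, 128), dtype="i4")}
sharded = shard(batch)  # splits the leading axis across jax.local_device_count() devices
print(sharded["input_ids"].shape)  # e.g. (8, 250, 128) on an 8-device host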
+
677
+ # Define eval fn
678
+ def eval_step(params, batch):
679
+ labels = batch.pop("labels")
680
+
681
+ logits = model(**batch, params=params, train=False)[0]
682
+
683
+ # compute loss, ignore padded input tokens
684
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
685
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
686
+
687
+ # compute accuracy
688
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
689
+
690
+ # summarize metrics
691
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
692
+ metrics = jax.lax.psum(metrics, axis_name="batch")
693
+
694
+ return metrics
695
+
696
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
697
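eval_step returns per-batch sums plus a normalizer; a toy sketch of how those sums are later reduced to mean loss and accuracy (the numbers are made up):

import jax
import jax.numpy as jnp

eval_metrics = [
    {"loss": jnp.array(120.0), "accuracy": jnp.array(30.0), "normalizer": jnp.array(50.0)},
    {"loss": jnp.array(100.0), "accuracy": jnp.array(35.0), "normalizer": jnp.array(50.0)},
]
totals = jax.tree_map(lambda *xs: jnp.sum(jnp.stack(xs)), *eval_metrics)
normalizer = totals.pop("normalizer")
print(jax.tree_map(lambda x: x / normalizer, totals))  # accuracy ~ 0.65, loss ~ 2.2 per masked token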
+
698
+ # Replicate the train state on each device
699
+ state = jax_utils.replicate(state)
700
+
701
+ train_time = 0
702
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
703
+ for epoch in epochs:
704
+ # ======================== Training ================================
705
+ train_start = time.time()
706
+ train_metrics = []
707
+
708
+ # Create sampling rng
709
+ rng, input_rng = jax.random.split(rng)
710
+
711
+ # Generate an epoch by shuffling sampling indices from the train dataset
712
+ num_train_samples = len(tokenized_datasets["train"])
713
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
714
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
715
+
716
+ # Gather the indexes for creating the batch and do a training step
717
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
718
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
719
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
720
+
721
+ # Model forward
722
+ model_inputs = shard(model_inputs.data)
723
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
724
+ train_metrics.append(train_metric)
725
+
726
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
727
+
728
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
729
+ # Save metrics
730
+ train_metric = jax_utils.unreplicate(train_metric)
731
+ train_time += time.time() - train_start
732
+ if has_tensorboard and jax.process_index() == 0:
733
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
734
+
735
+ epochs.write(
736
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
737
+ )
738
+
739
+ train_metrics = []
740
+
741
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
742
+ # ======================== Evaluating ==============================
743
+ num_eval_samples = len(tokenized_datasets["validation"])
744
+ eval_samples_idx = jnp.arange(num_eval_samples)
745
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
746
+
747
+ eval_metrics = []
748
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
749
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
750
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
751
+
752
+ # Model forward
753
+ model_inputs = shard(model_inputs.data)
754
+ metrics = p_eval_step(state.params, model_inputs)
755
+ eval_metrics.append(metrics)
756
+
757
+ # normalize eval metrics
758
+ eval_metrics = get_metrics(eval_metrics)
759
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
760
+ eval_normalizer = eval_metrics.pop("normalizer")
761
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
762
+
763
+ # Update progress bar
764
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
765
+
766
+ # Save metrics
767
+ if has_tensorboard and jax.process_index() == 0:
768
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
769
+
770
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
771
+ # Save a checkpoint every save_steps and push it to the hub
772
+ if jax.process_index() == 0:
773
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
774
+ model.save_pretrained(training_args.output_dir, params=params)
775
+ tokenizer.save_pretrained(training_args.output_dir)
776
+ if training_args.push_to_hub:
777
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
778
+
779
+ # Eval after training
780
+ if training_args.do_eval:
781
+ num_eval_samples = len(tokenized_datasets["validation"])
782
+ eval_samples_idx = jnp.arange(num_eval_samples)
783
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
784
+
785
+ eval_metrics = []
786
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
787
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
788
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
789
+
790
+ # Model forward
791
+ model_inputs = shard(model_inputs.data)
792
+ metrics = p_eval_step(state.params, model_inputs)
793
+ eval_metrics.append(metrics)
794
+
795
+ # normalize eval metrics
796
+ eval_metrics = get_metrics(eval_metrics)
797
+ eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
798
+ eval_normalizer = eval_metrics.pop("normalizer")
799
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
800
+
801
+ try:
802
+ perplexity = math.exp(eval_metrics["loss"])
803
+ except OverflowError:
804
+ perplexity = float("inf")
805
+ eval_metrics["perplexity"] = perplexity
806
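Since the eval loss is an average cross-entropy per masked token, the reported perplexity is just its exponential, for example:

import math

eval_loss = 2.1                       # made-up value
print(round(math.exp(eval_loss), 2))  # 8.17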
+
807
+ if jax.process_index() == 0:
808
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
809
+ path = os.path.join(training_args.output_dir, "eval_results.json")
810
+ with open(path, "w") as f:
811
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
812
+
813
+
814
+ if __name__ == "__main__":
815
+ main()
wandb/run-20220114_212855-32qdb4k5/files/config.yaml ADDED
@@ -0,0 +1,152 @@
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_mlm_flax.py
8
+ framework: huggingface
9
+ huggingface_version: 4.16.0.dev0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1642195735
14
+ t:
15
+ 1:
16
+ - 2
17
+ - 3
18
+ - 11
19
+ - 12
20
+ 2:
21
+ - 2
22
+ - 3
23
+ - 11
24
+ - 12
25
+ 4: 3.8.10
26
+ 5: 0.12.9
27
+ 6: 4.16.0.dev0
28
+ 8:
29
+ - 5
30
+ adafactor:
31
+ desc: null
32
+ value: false
33
+ adam_beta1:
34
+ desc: null
35
+ value: 0.9
36
+ adam_beta2:
37
+ desc: null
38
+ value: 0.98
39
+ adam_epsilon:
40
+ desc: null
41
+ value: 1.0e-06
42
+ cache_dir:
43
+ desc: null
44
+ value: null
45
+ config_name:
46
+ desc: null
47
+ value: roberta-base
48
+ dataset_config_name:
49
+ desc: null
50
+ value: null
51
+ dataset_name:
52
+ desc: null
53
+ value: NbAiLab/NCC
54
+ do_eval:
55
+ desc: null
56
+ value: true
57
+ do_train:
58
+ desc: null
59
+ value: true
60
+ dtype:
61
+ desc: null
62
+ value: bfloat16
63
+ eval_steps:
64
+ desc: null
65
+ value: 1000
66
+ hub_model_id:
67
+ desc: null
68
+ value: null
69
+ hub_token:
70
+ desc: null
71
+ value: null
72
+ learning_rate:
73
+ desc: null
74
+ value: 0.0006
75
+ line_by_line:
76
+ desc: null
77
+ value: false
78
+ logging_steps:
79
+ desc: null
80
+ value: 1000
81
+ max_seq_length:
82
+ desc: null
83
+ value: 128
84
+ mlm_probability:
85
+ desc: null
86
+ value: 0.15
87
+ model_name_or_path:
88
+ desc: null
89
+ value: null
90
+ model_type:
91
+ desc: null
92
+ value: roberta
93
+ num_train_epochs:
94
+ desc: null
95
+ value: 3.0
96
+ output_dir:
97
+ desc: null
98
+ value: ./
99
+ overwrite_cache:
100
+ desc: null
101
+ value: false
102
+ overwrite_output_dir:
103
+ desc: null
104
+ value: true
105
+ pad_to_max_length:
106
+ desc: null
107
+ value: true
108
+ per_device_eval_batch_size:
109
+ desc: null
110
+ value: 250
111
+ per_device_train_batch_size:
112
+ desc: null
113
+ value: 250
114
+ preprocessing_num_workers:
115
+ desc: null
116
+ value: null
117
+ push_to_hub:
118
+ desc: null
119
+ value: true
120
+ save_steps:
121
+ desc: null
122
+ value: 1000
123
+ seed:
124
+ desc: null
125
+ value: 42
126
+ tokenizer_name:
127
+ desc: null
128
+ value: NbAiLab/nb-roberta-base
129
+ train_file:
130
+ desc: null
131
+ value: null
132
+ train_ref_file:
133
+ desc: null
134
+ value: null
135
+ use_fast_tokenizer:
136
+ desc: null
137
+ value: true
138
+ validation_file:
139
+ desc: null
140
+ value: null
141
+ validation_ref_file:
142
+ desc: null
143
+ value: null
144
+ validation_split_percentage:
145
+ desc: null
146
+ value: 5
147
+ warmup_steps:
148
+ desc: null
149
+ value: 10000
150
+ weight_decay:
151
+ desc: null
152
+ value: 0.01
wandb/run-20220114_212855-32qdb4k5/files/diff.patch ADDED
File without changes
wandb/run-20220114_212855-32qdb4k5/files/output.log ADDED
@@ -0,0 +1,43 @@
1
+ 2022-01-14 21:29:01.798913: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2
+ 2022-01-14 21:29:01.798960: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
3
+ Epoch ... (1/3): 0%| | 0/3 [00:00<?, ?it/s]
4
+ Training...: 0%| | 0/39919 [02:17<?, ?it/s]
5
+ Epoch ... (1/3): 0%| | 0/3 [03:05<?, ?it/s]
6
+ Traceback (most recent call last):
7
+ File "run_mlm_flax.py", line 815, in <module>
8
+ main()
9
+ File "run_mlm_flax.py", line 723, in main
10
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
11
+ File "/data/flax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
12
+ return fun(*args, **kwargs)
13
+ File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 2058, in cache_miss
14
+ out_tree, out_flat = f_pmapped_(*args, **kwargs)
15
+ File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 1934, in f_pmapped
16
+ out = pxla.xla_pmap(
17
+ File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 1727, in bind
18
+ return call_bind(self, fun, *args, **params)
19
+ File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 1652, in call_bind
20
+ outs = primitive.process(top_trace, fun, tracers, params)
21
+ File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 1730, in process
22
+ return trace.process_map(self, fun, tracers, params)
23
+ File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 633, in process_call
24
+ return primitive.impl(f, *tracers, **params)
25
+ File "/data/flax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 778, in xla_pmap_impl
26
+ return compiled_fun(*args)
27
+ File "/data/flax/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
28
+ return func(*args, **kwargs)
29
+ File "/data/flax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1502, in execute_replicated
30
+ out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
31
+ jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 12.83G at the bottom of memory. That was not possible. There are 13.18G free, 0B reserved, and 12.71G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
32
+ The stack trace below excludes JAX-internal frames.
33
+ The preceding is the original exception that occurred, unmodified.
34
+ --------------------
35
+ The above exception was the direct cause of the following exception:
36
+ Traceback (most recent call last):
37
+ File "run_mlm_flax.py", line 815, in <module>
38
+ main()
39
+ File "run_mlm_flax.py", line 723, in main
40
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
41
+ File "/data/flax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1502, in execute_replicated
42
+ out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
43
+ RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 12.83G at the bottom of memory. That was not possible. There are 13.18G free, 0B reserved, and 12.71G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
wandb/run-20220114_212855-32qdb4k5/files/requirements.txt ADDED
@@ -0,0 +1,122 @@
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ astunparse==1.6.3
5
+ async-timeout==4.0.2
6
+ attrs==21.4.0
7
+ backcall==0.2.0
8
+ cachetools==4.2.4
9
+ certifi==2021.10.8
10
+ charset-normalizer==2.0.10
11
+ chex==0.1.0
12
+ click==8.0.3
13
+ clu==0.0.6
14
+ configparser==5.2.0
15
+ contextlib2==21.6.0
16
+ cycler==0.11.0
17
+ datasets==1.17.1.dev0
18
+ decorator==5.1.0
19
+ dill==0.3.4
20
+ dm-tree==0.1.6
21
+ docker-pycreds==0.4.0
22
+ filelock==3.4.2
23
+ flatbuffers==2.0
24
+ flax==0.3.6
25
+ fonttools==4.28.5
26
+ frozenlist==1.2.0
27
+ fsspec==2021.11.1
28
+ future==0.18.2
29
+ gast==0.4.0
30
+ gitdb==4.0.9
31
+ gitpython==3.1.26
32
+ google-auth-oauthlib==0.4.6
33
+ google-auth==2.3.3
34
+ google-pasta==0.2.0
35
+ googleapis-common-protos==1.54.0
36
+ grpcio==1.43.0
37
+ h5py==3.6.0
38
+ huggingface-hub==0.2.1
39
+ idna==3.3
40
+ importlib-metadata==4.10.0
41
+ importlib-resources==5.4.0
42
+ ipython==7.31.0
43
+ jax==0.2.26
44
+ jaxlib==0.1.75
45
+ jedi==0.18.1
46
+ joblib==1.1.0
47
+ keras-preprocessing==1.1.2
48
+ keras==2.7.0
49
+ kiwisolver==1.3.2
50
+ libclang==12.0.0
51
+ libtpu-nightly==0.1.dev20211208
52
+ markdown==3.3.6
53
+ matplotlib-inline==0.1.3
54
+ matplotlib==3.5.1
55
+ ml-collections==0.1.0
56
+ msgpack==1.0.3
57
+ multidict==5.2.0
58
+ multiprocess==0.70.12.2
59
+ numpy==1.22.0
60
+ oauthlib==3.1.1
61
+ opt-einsum==3.3.0
62
+ optax==0.1.0
63
+ packaging==21.3
64
+ pandas==1.3.5
65
+ parso==0.8.3
66
+ pathtools==0.1.2
67
+ pexpect==4.8.0
68
+ pickleshare==0.7.5
69
+ pillow==9.0.0
70
+ pip==20.0.2
71
+ pkg-resources==0.0.0
72
+ promise==2.3
73
+ prompt-toolkit==3.0.24
74
+ protobuf==3.19.1
75
+ psutil==5.9.0
76
+ ptyprocess==0.7.0
77
+ pyarrow==6.0.1
78
+ pyasn1-modules==0.2.8
79
+ pyasn1==0.4.8
80
+ pygments==2.11.1
81
+ pyparsing==3.0.6
82
+ python-dateutil==2.8.2
83
+ pytz==2021.3
84
+ pyyaml==6.0
85
+ regex==2021.11.10
86
+ requests-oauthlib==1.3.0
87
+ requests==2.27.0
88
+ rsa==4.8
89
+ sacremoses==0.0.46
90
+ scipy==1.7.3
91
+ sentry-sdk==1.5.2
92
+ setuptools==44.0.0
93
+ shortuuid==1.0.8
94
+ six==1.16.0
95
+ smmap==5.0.0
96
+ subprocess32==3.5.4
97
+ tensorboard-data-server==0.6.1
98
+ tensorboard-plugin-wit==1.8.0
99
+ tensorboard==2.7.0
100
+ tensorflow-cpu==2.7.0
101
+ tensorflow-datasets==4.4.0
102
+ tensorflow-estimator==2.7.0
103
+ tensorflow-io-gcs-filesystem==0.23.1
104
+ tensorflow-metadata==1.5.0
105
+ tensorflow==2.7.0
106
+ termcolor==1.1.0
107
+ tokenizers==0.11.2
108
+ toolz==0.11.2
109
+ tqdm==4.62.3
110
+ traitlets==5.1.1
111
+ transformers==4.16.0.dev0
112
+ typing-extensions==3.10.0.2
113
+ urllib3==1.26.7
114
+ wandb==0.12.9
115
+ wcwidth==0.2.5
116
+ werkzeug==2.0.2
117
+ wheel==0.37.1
118
+ wrapt==1.13.3
119
+ xxhash==2.0.2
120
+ yarl==1.7.2
121
+ yaspin==2.1.0
122
+ zipp==3.7.0
wandb/run-20220114_212855-32qdb4k5/files/wandb-metadata.json ADDED
@@ -0,0 +1,47 @@
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-01-14T21:28:58.974844",
5
+ "startedAt": "2022-01-14T21:28:55.397355",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--output_dir=./",
11
+ "--model_type=roberta",
12
+ "--config_name=roberta-base",
13
+ "--tokenizer_name=NbAiLab/nb-roberta-base",
14
+ "--dataset_name=NbAiLab/NCC",
15
+ "--max_seq_length=128",
16
+ "--weight_decay=0.01",
17
+ "--per_device_train_batch_size=250",
18
+ "--per_device_eval_batch_size=250",
19
+ "--pad_to_max_length",
20
+ "--learning_rate=6e-4",
21
+ "--warmup_steps=10000",
22
+ "--overwrite_output_dir",
23
+ "--num_train_epochs=3",
24
+ "--adam_beta1=0.9",
25
+ "--adam_beta2=0.98",
26
+ "--adam_epsilon=1e-6",
27
+ "--logging_steps=1000",
28
+ "--save_steps=1000",
29
+ "--eval_steps=1000",
30
+ "--do_train",
31
+ "--do_eval",
32
+ "--dtype=bfloat16",
33
+ "--push_to_hub"
34
+ ],
35
+ "state": "running",
36
+ "program": "run_mlm_flax.py",
37
+ "codePath": "run_mlm_flax.py",
38
+ "git": {
39
+ "remote": "https://huggingface.co/versae/roberta-base-ncc",
40
+ "commit": "502df078f73cf93ca9380fcac1c9b9c7598a445f"
41
+ },
42
+ "email": "versae@gmail.com",
43
+ "root": "/data/roberta-base-ncc",
44
+ "host": "t1v-n-eedfb410-w-0",
45
+ "username": "javierr",
46
+ "executable": "/data/flax/bin/python"
47
+ }
wandb/run-20220114_212855-32qdb4k5/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"_wandb": {"runtime": 200}}
wandb/run-20220114_212855-32qdb4k5/logs/debug-internal.log ADDED
@@ -0,0 +1,189 @@
1
+ 2022-01-14 21:28:56,265 INFO MainThread:8253 [internal.py:wandb_internal():87] W&B internal server running at pid: 8253, started at: 2022-01-14 21:28:56.265129
2
+ 2022-01-14 21:28:56,268 DEBUG SenderThread:8253 [sender.py:send():234] send: header
3
+ 2022-01-14 21:28:56,268 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: check_version
4
+ 2022-01-14 21:28:56,268 INFO WriterThread:8253 [datastore.py:open_for_write():77] open: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/run-32qdb4k5.wandb
5
+ 2022-01-14 21:28:56,268 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: check_version
6
+ 2022-01-14 21:28:56,352 DEBUG SenderThread:8253 [sender.py:send():234] send: run
7
+ 2022-01-14 21:28:56,515 INFO SenderThread:8253 [dir_watcher.py:__init__():169] watching files in: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files
8
+ 2022-01-14 21:28:56,515 INFO SenderThread:8253 [sender.py:_start_run_threads():804] run started: 32qdb4k5 with start time 1642195735
9
+ 2022-01-14 21:28:56,515 DEBUG SenderThread:8253 [sender.py:send():234] send: summary
10
+ 2022-01-14 21:28:56,515 INFO SenderThread:8253 [sender.py:_save_file():939] saving file wandb-summary.json with policy end
11
+ 2022-01-14 21:28:56,515 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: run_start
12
+ 2022-01-14 21:28:57,561 INFO Thread-8 :8253 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/wandb-summary.json
13
+ 2022-01-14 21:28:58,974 DEBUG HandlerThread:8253 [meta.py:__init__():40] meta init
14
+ 2022-01-14 21:28:58,974 DEBUG HandlerThread:8253 [meta.py:__init__():54] meta init done
15
+ 2022-01-14 21:28:58,974 DEBUG HandlerThread:8253 [meta.py:probe():214] probe
16
+ 2022-01-14 21:28:58,975 DEBUG HandlerThread:8253 [meta.py:_setup_git():204] setup git
17
+ 2022-01-14 21:28:59,006 DEBUG HandlerThread:8253 [meta.py:_setup_git():211] setup git done
18
+ 2022-01-14 21:28:59,006 DEBUG HandlerThread:8253 [meta.py:_save_code():92] save code
19
+ 2022-01-14 21:28:59,018 DEBUG HandlerThread:8253 [meta.py:_save_code():113] save code done
20
+ 2022-01-14 21:28:59,018 DEBUG HandlerThread:8253 [meta.py:_save_patches():130] save patches
21
+ 2022-01-14 21:28:59,561 INFO Thread-8 :8253 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/code/run_mlm_flax.py
22
+ 2022-01-14 21:28:59,562 INFO Thread-8 :8253 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/code
23
+ 2022-01-14 21:29:01,562 INFO Thread-8 :8253 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/output.log
24
+ 2022-01-14 21:29:03,563 INFO Thread-8 :8253 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/output.log
25
+ 2022-01-14 21:29:03,563 INFO Thread-8 :8253 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/diff.patch
26
+ 2022-01-14 21:29:08,274 ERROR HandlerThread:8253 [meta.py:_save_patches():171] Error generating diff: Command '['git', 'diff', '--submodule=diff', 'HEAD']' timed out after 5 seconds
27
+ 2022-01-14 21:29:08,274 DEBUG HandlerThread:8253 [meta.py:_save_patches():172] save patches done
28
+ 2022-01-14 21:29:08,274 DEBUG HandlerThread:8253 [meta.py:_save_pip():58] save pip
29
+ 2022-01-14 21:29:08,275 DEBUG HandlerThread:8253 [meta.py:_save_pip():72] save pip done
30
+ 2022-01-14 21:29:08,275 DEBUG HandlerThread:8253 [meta.py:probe():252] probe done
31
+ 2022-01-14 21:29:08,283 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
32
+ 2022-01-14 21:29:08,284 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
33
+ 2022-01-14 21:29:08,566 DEBUG SenderThread:8253 [sender.py:send():234] send: config
34
+ 2022-01-14 21:29:08,566 DEBUG SenderThread:8253 [sender.py:send():234] send: config
35
+ 2022-01-14 21:29:08,566 DEBUG SenderThread:8253 [sender.py:send():234] send: config
36
+ 2022-01-14 21:29:08,567 DEBUG SenderThread:8253 [sender.py:send():234] send: files
37
+ 2022-01-14 21:29:08,567 INFO SenderThread:8253 [sender.py:_save_file():939] saving file wandb-metadata.json with policy now
38
+ 2022-01-14 21:29:08,567 INFO SenderThread:8253 [sender.py:_save_file():939] saving file code/run_mlm_flax.py with policy now
39
+ 2022-01-14 21:29:08,571 INFO Thread-8 :8253 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/requirements.txt
40
+ 2022-01-14 21:29:08,571 INFO Thread-8 :8253 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/wandb-metadata.json
41
+ 2022-01-14 21:29:09,032 INFO Thread-12 :8253 [upload_job.py:push():137] Uploaded file /tmp/tmpdg54qv_0wandb/w1tibuxq-code/run_mlm_flax.py
42
+ 2022-01-14 21:29:09,069 INFO Thread-11 :8253 [upload_job.py:push():137] Uploaded file /tmp/tmpdg54qv_0wandb/35h4ryp5-wandb-metadata.json
43
+ 2022-01-14 21:29:09,571 INFO Thread-8 :8253 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/output.log
44
+ 2022-01-14 21:29:21,520 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
45
+ 2022-01-14 21:29:21,521 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
46
+ 2022-01-14 21:29:27,049 DEBUG SenderThread:8253 [sender.py:send():234] send: stats
47
+ 2022-01-14 21:29:27,579 INFO Thread-8 :8253 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/config.yaml
48
+ 2022-01-14 21:29:36,656 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
49
+ 2022-01-14 21:29:36,657 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
50
+ 2022-01-14 21:29:51,794 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
51
+ 2022-01-14 21:29:51,795 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
52
+ 2022-01-14 21:29:57,119 DEBUG SenderThread:8253 [sender.py:send():234] send: stats
53
+ 2022-01-14 21:29:58,591 INFO Thread-8 :8253 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/output.log
54
+ 2022-01-14 21:30:06,945 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
55
+ 2022-01-14 21:30:06,945 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
56
+ 2022-01-14 21:30:22,126 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
57
+ 2022-01-14 21:30:22,127 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
58
+ 2022-01-14 21:30:27,189 DEBUG SenderThread:8253 [sender.py:send():234] send: stats
59
+ 2022-01-14 21:30:37,330 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
60
+ 2022-01-14 21:30:37,330 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
61
+ 2022-01-14 21:30:52,532 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
62
+ 2022-01-14 21:30:52,532 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
63
+ 2022-01-14 21:30:57,257 DEBUG SenderThread:8253 [sender.py:send():234] send: stats
64
+ 2022-01-14 21:31:07,691 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
65
+ 2022-01-14 21:31:07,692 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
66
+ 2022-01-14 21:31:22,944 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
67
+ 2022-01-14 21:31:22,945 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
68
+ 2022-01-14 21:31:27,323 DEBUG SenderThread:8253 [sender.py:send():234] send: stats
69
+ 2022-01-14 21:31:38,085 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
70
+ 2022-01-14 21:31:38,086 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
71
+ 2022-01-14 21:31:53,231 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
72
+ 2022-01-14 21:31:53,231 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
73
+ 2022-01-14 21:31:57,395 DEBUG SenderThread:8253 [sender.py:send():234] send: stats
74
+ 2022-01-14 21:32:08,366 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: stop_status
75
+ 2022-01-14 21:32:08,367 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: stop_status
76
+ 2022-01-14 21:32:16,649 INFO Thread-8 :8253 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/output.log
77
+ 2022-01-14 21:32:16,893 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
78
+ 2022-01-14 21:32:16,893 DEBUG SenderThread:8253 [sender.py:send():234] send: telemetry
79
+ 2022-01-14 21:32:16,893 DEBUG SenderThread:8253 [sender.py:send():234] send: exit
80
+ 2022-01-14 21:32:16,893 INFO SenderThread:8253 [sender.py:send_exit():366] handling exit code: 1
81
+ 2022-01-14 21:32:16,894 INFO SenderThread:8253 [sender.py:send_exit():368] handling runtime: 200
82
+ 2022-01-14 21:32:16,894 INFO SenderThread:8253 [sender.py:_save_file():939] saving file wandb-summary.json with policy end
83
+ 2022-01-14 21:32:16,894 INFO SenderThread:8253 [sender.py:send_exit():374] send defer
84
+ 2022-01-14 21:32:16,894 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
85
+ 2022-01-14 21:32:16,895 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
86
+ 2022-01-14 21:32:16,895 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 0
87
+ 2022-01-14 21:32:16,895 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
88
+ 2022-01-14 21:32:16,895 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 0
89
+ 2022-01-14 21:32:16,895 INFO SenderThread:8253 [sender.py:transition_state():387] send defer: 1
90
+ 2022-01-14 21:32:16,896 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
91
+ 2022-01-14 21:32:16,896 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 1
92
+ 2022-01-14 21:32:16,941 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
93
+ 2022-01-14 21:32:16,941 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 1
94
+ 2022-01-14 21:32:16,941 INFO SenderThread:8253 [sender.py:transition_state():387] send defer: 2
95
+ 2022-01-14 21:32:16,941 DEBUG SenderThread:8253 [sender.py:send():234] send: stats
96
+ 2022-01-14 21:32:16,942 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
97
+ 2022-01-14 21:32:16,942 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 2
98
+ 2022-01-14 21:32:16,942 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
99
+ 2022-01-14 21:32:16,942 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 2
100
+ 2022-01-14 21:32:16,942 INFO SenderThread:8253 [sender.py:transition_state():387] send defer: 3
101
+ 2022-01-14 21:32:16,942 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
102
+ 2022-01-14 21:32:16,942 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 3
103
+ 2022-01-14 21:32:16,943 DEBUG SenderThread:8253 [sender.py:send():234] send: summary
104
+ 2022-01-14 21:32:16,943 INFO SenderThread:8253 [sender.py:_save_file():939] saving file wandb-summary.json with policy end
105
+ 2022-01-14 21:32:16,943 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
106
+ 2022-01-14 21:32:16,943 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 3
107
+ 2022-01-14 21:32:16,943 INFO SenderThread:8253 [sender.py:transition_state():387] send defer: 4
108
+ 2022-01-14 21:32:16,943 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
109
+ 2022-01-14 21:32:16,943 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 4
110
+ 2022-01-14 21:32:16,944 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
111
+ 2022-01-14 21:32:16,944 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 4
112
+ 2022-01-14 21:32:16,997 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
113
+ 2022-01-14 21:32:17,650 INFO Thread-8 :8253 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/output.log
114
+ 2022-01-14 21:32:17,650 INFO Thread-8 :8253 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/wandb-summary.json
115
+ 2022-01-14 21:32:17,685 INFO SenderThread:8253 [sender.py:transition_state():387] send defer: 5
116
+ 2022-01-14 21:32:17,686 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
117
+ 2022-01-14 21:32:17,686 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
118
+ 2022-01-14 21:32:17,686 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 5
119
+ 2022-01-14 21:32:17,686 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
120
+ 2022-01-14 21:32:17,686 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 5
121
+ 2022-01-14 21:32:17,687 INFO SenderThread:8253 [dir_watcher.py:finish():283] shutting down directory watcher
122
+ 2022-01-14 21:32:17,787 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
123
+ 2022-01-14 21:32:18,650 INFO Thread-8 :8253 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/config.yaml
124
+ 2022-01-14 21:32:18,651 INFO SenderThread:8253 [dir_watcher.py:finish():313] scan: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files
125
+ 2022-01-14 21:32:18,651 INFO SenderThread:8253 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/config.yaml config.yaml
126
+ 2022-01-14 21:32:18,651 INFO SenderThread:8253 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/diff.patch diff.patch
127
+ 2022-01-14 21:32:18,651 INFO SenderThread:8253 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/requirements.txt requirements.txt
128
+ 2022-01-14 21:32:18,652 INFO SenderThread:8253 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/output.log output.log
129
+ 2022-01-14 21:32:18,652 INFO SenderThread:8253 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/wandb-summary.json wandb-summary.json
130
+ 2022-01-14 21:32:18,652 INFO SenderThread:8253 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/wandb-metadata.json wandb-metadata.json
131
+ 2022-01-14 21:32:18,656 INFO SenderThread:8253 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/code/run_mlm_flax.py code/run_mlm_flax.py
132
+ 2022-01-14 21:32:18,656 INFO SenderThread:8253 [sender.py:transition_state():387] send defer: 6
133
+ 2022-01-14 21:32:18,656 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
134
+ 2022-01-14 21:32:18,657 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
135
+ 2022-01-14 21:32:18,657 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 6
136
+ 2022-01-14 21:32:18,662 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
137
+ 2022-01-14 21:32:18,665 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 6
138
+ 2022-01-14 21:32:18,665 INFO SenderThread:8253 [file_pusher.py:finish():177] shutting down file pusher
139
+ 2022-01-14 21:32:18,757 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
140
+ 2022-01-14 21:32:18,758 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
141
+ 2022-01-14 21:32:18,859 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
142
+ 2022-01-14 21:32:18,860 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
143
+ 2022-01-14 21:32:18,961 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
144
+ 2022-01-14 21:32:18,962 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
145
+ 2022-01-14 21:32:19,063 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
146
+ 2022-01-14 21:32:19,064 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
147
+ 2022-01-14 21:32:19,139 INFO Thread-15 :8253 [upload_job.py:push():137] Uploaded file /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/output.log
148
+ 2022-01-14 21:32:19,148 INFO Thread-14 :8253 [upload_job.py:push():137] Uploaded file /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/requirements.txt
149
+ 2022-01-14 21:32:19,165 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
150
+ 2022-01-14 21:32:19,165 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
151
+ 2022-01-14 21:32:19,171 INFO Thread-13 :8253 [upload_job.py:push():137] Uploaded file /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/config.yaml
152
+ 2022-01-14 21:32:19,267 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
153
+ 2022-01-14 21:32:19,267 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
154
+ 2022-01-14 21:32:19,288 INFO Thread-16 :8253 [upload_job.py:push():137] Uploaded file /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/files/wandb-summary.json
155
+ 2022-01-14 21:32:19,370 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
156
+ 2022-01-14 21:32:19,370 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
157
+ 2022-01-14 21:32:19,472 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
158
+ 2022-01-14 21:32:19,472 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
159
+ 2022-01-14 21:32:19,489 INFO Thread-7 :8253 [sender.py:transition_state():387] send defer: 7
160
+ 2022-01-14 21:32:19,489 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
161
+ 2022-01-14 21:32:19,490 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 7
162
+ 2022-01-14 21:32:19,490 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
163
+ 2022-01-14 21:32:19,490 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 7
164
+ 2022-01-14 21:32:19,573 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
165
+ 2022-01-14 21:32:19,915 INFO SenderThread:8253 [sender.py:transition_state():387] send defer: 8
166
+ 2022-01-14 21:32:19,916 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
167
+ 2022-01-14 21:32:19,916 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
168
+ 2022-01-14 21:32:19,916 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 8
169
+ 2022-01-14 21:32:19,917 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
170
+ 2022-01-14 21:32:19,917 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 8
171
+ 2022-01-14 21:32:19,917 INFO SenderThread:8253 [sender.py:transition_state():387] send defer: 9
172
+ 2022-01-14 21:32:19,917 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: defer
173
+ 2022-01-14 21:32:19,917 INFO HandlerThread:8253 [handler.py:handle_request_defer():147] handle defer: 9
174
+ 2022-01-14 21:32:19,918 DEBUG SenderThread:8253 [sender.py:send():234] send: final
175
+ 2022-01-14 21:32:19,918 DEBUG SenderThread:8253 [sender.py:send():234] send: footer
176
+ 2022-01-14 21:32:19,918 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: defer
177
+ 2022-01-14 21:32:19,918 INFO SenderThread:8253 [sender.py:send_request_defer():383] handle sender defer: 9
178
+ 2022-01-14 21:32:20,017 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: poll_exit
179
+ 2022-01-14 21:32:20,018 DEBUG SenderThread:8253 [sender.py:send_request():248] send_request: poll_exit
180
+ 2022-01-14 21:32:20,018 INFO SenderThread:8253 [file_pusher.py:join():182] waiting for file pusher
181
+ 2022-01-14 21:32:20,278 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: get_summary
182
+ 2022-01-14 21:32:20,278 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: sampled_history
183
+ 2022-01-14 21:32:20,279 DEBUG HandlerThread:8253 [handler.py:handle_request():130] handle_request: shutdown
184
+ 2022-01-14 21:32:20,279 INFO HandlerThread:8253 [handler.py:finish():731] shutting down handler
185
+ 2022-01-14 21:32:20,918 INFO WriterThread:8253 [datastore.py:close():281] close: /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/run-32qdb4k5.wandb
186
+ 2022-01-14 21:32:21,277 INFO SenderThread:8253 [sender.py:finish():1070] shutting down sender
187
+ 2022-01-14 21:32:21,277 INFO SenderThread:8253 [file_pusher.py:finish():177] shutting down file pusher
188
+ 2022-01-14 21:32:21,277 INFO SenderThread:8253 [file_pusher.py:join():182] waiting for file pusher
189
+ 2022-01-14 21:32:21,279 INFO MainThread:8253 [internal.py:handle_exit():77] Internal process exited
wandb/run-20220114_212855-32qdb4k5/logs/debug.log ADDED
@@ -0,0 +1,150 @@
1
+ 2022-01-14 21:28:55,408 INFO MainThread:5000 [wandb_setup.py:_flush():71] setting env: {}
2
+ 2022-01-14 21:28:55,408 INFO MainThread:5000 [wandb_setup.py:_flush():71] setting login settings: {}
3
+ 2022-01-14 21:28:55,408 INFO MainThread:5000 [wandb_init.py:_log_setup():371] Logging user logs to /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/logs/debug.log
4
+ 2022-01-14 21:28:55,408 INFO MainThread:5000 [wandb_init.py:_log_setup():372] Logging internal logs to /data/roberta-base-ncc/wandb/run-20220114_212855-32qdb4k5/logs/debug-internal.log
5
+ 2022-01-14 21:28:55,408 INFO MainThread:5000 [wandb_init.py:init():404] calling init triggers
6
+ 2022-01-14 21:28:55,408 INFO MainThread:5000 [wandb_init.py:init():409] wandb.init called with sweep_config: {}
7
+ config: {}
8
+ 2022-01-14 21:28:55,409 INFO MainThread:5000 [wandb_init.py:init():460] starting backend
9
+ 2022-01-14 21:28:55,409 INFO MainThread:5000 [backend.py:_multiprocessing_setup():99] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
10
+ 2022-01-14 21:28:55,461 INFO MainThread:5000 [backend.py:ensure_launched():216] starting backend process...
11
+ 2022-01-14 21:28:55,488 INFO MainThread:5000 [backend.py:ensure_launched():221] started backend process with pid: 8253
12
+ 2022-01-14 21:28:55,490 INFO MainThread:5000 [wandb_init.py:init():469] backend started and connected
13
+ 2022-01-14 21:28:55,501 INFO MainThread:5000 [wandb_init.py:init():533] updated telemetry
14
+ 2022-01-14 21:28:55,563 INFO MainThread:5000 [wandb_init.py:init():563] communicating current version
15
+ 2022-01-14 21:28:56,351 INFO MainThread:5000 [wandb_init.py:init():568] got version response
16
+ 2022-01-14 21:28:56,351 INFO MainThread:5000 [wandb_init.py:init():578] communicating run to backend with 30 second timeout
17
+ 2022-01-14 21:28:56,515 INFO MainThread:5000 [wandb_init.py:init():606] starting run threads in backend
18
+ 2022-01-14 21:29:01,520 INFO MainThread:5000 [wandb_run.py:_console_start():1810] atexit reg
19
+ 2022-01-14 21:29:01,520 INFO MainThread:5000 [wandb_run.py:_redirect():1684] redirect: SettingsConsole.REDIRECT
20
+ 2022-01-14 21:29:01,520 INFO MainThread:5000 [wandb_run.py:_redirect():1689] Redirecting console.
21
+ 2022-01-14 21:29:01,523 INFO MainThread:5000 [wandb_run.py:_redirect():1745] Redirects installed.
22
+ 2022-01-14 21:29:01,523 INFO MainThread:5000 [wandb_init.py:init():633] run started, returning control to user process
23
+ 2022-01-14 21:29:01,523 INFO MainThread:5000 [wandb_run.py:_config_callback():956] config_cb None None {'output_dir': './', 'overwrite_output_dir': True, 'do_train': True, 'do_eval': True, 'per_device_train_batch_size': 250, 'per_device_eval_batch_size': 250, 'learning_rate': 0.0006, 'weight_decay': 0.01, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-06, 'adafactor': False, 'num_train_epochs': 3.0, 'warmup_steps': 10000, 'logging_steps': 1000, 'save_steps': 1000, 'eval_steps': 1000, 'seed': 42, 'push_to_hub': True, 'hub_model_id': None, 'hub_token': None}
24
+ 2022-01-14 21:29:01,524 INFO MainThread:5000 [wandb_run.py:_config_callback():956] config_cb None None {'model_name_or_path': None, 'model_type': 'roberta', 'config_name': 'roberta-base', 'tokenizer_name': 'NbAiLab/nb-roberta-base', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'bfloat16'}
25
+ 2022-01-14 21:29:01,524 INFO MainThread:5000 [wandb_run.py:_config_callback():956] config_cb None None {'dataset_name': 'NbAiLab/NCC', 'dataset_config_name': None, 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 128, 'preprocessing_num_workers': None, 'mlm_probability': 0.15, 'pad_to_max_length': True, 'line_by_line': False}
26
+ 2022-01-14 21:32:14,189 INFO MainThread:5000 [wandb_run.py:_atexit_cleanup():1780] got exitcode: 1
27
+ 2022-01-14 21:32:14,192 INFO MainThread:5000 [wandb_run.py:_restore():1752] restore
28
+ 2022-01-14 21:32:16,895 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
29
+ wandb_count: 1
30
+ other_count: 1
31
+ }
32
+ pusher_stats {
33
+ uploaded_bytes: 37446
34
+ total_bytes: 37446
35
+ }
36
+
37
+ 2022-01-14 21:32:17,686 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
38
+ wandb_count: 1
39
+ other_count: 1
40
+ }
41
+ pusher_stats {
42
+ uploaded_bytes: 37446
43
+ total_bytes: 37446
44
+ }
45
+
46
+ 2022-01-14 21:32:18,657 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
47
+ wandb_count: 5
48
+ other_count: 1
49
+ }
50
+ pusher_stats {
51
+ uploaded_bytes: 37446
52
+ total_bytes: 45931
53
+ }
54
+
55
+ 2022-01-14 21:32:18,759 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
56
+ wandb_count: 5
57
+ other_count: 1
58
+ }
59
+ pusher_stats {
60
+ uploaded_bytes: 37446
61
+ total_bytes: 45931
62
+ }
63
+
64
+ 2022-01-14 21:32:18,861 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
65
+ wandb_count: 5
66
+ other_count: 1
67
+ }
68
+ pusher_stats {
69
+ uploaded_bytes: 45903
70
+ total_bytes: 45931
71
+ }
72
+
73
+ 2022-01-14 21:32:18,962 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
74
+ wandb_count: 5
75
+ other_count: 1
76
+ }
77
+ pusher_stats {
78
+ uploaded_bytes: 45903
79
+ total_bytes: 45931
80
+ }
81
+
82
+ 2022-01-14 21:32:19,064 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
83
+ wandb_count: 5
84
+ other_count: 1
85
+ }
86
+ pusher_stats {
87
+ uploaded_bytes: 45931
88
+ total_bytes: 45931
89
+ }
90
+
91
+ 2022-01-14 21:32:19,166 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
92
+ wandb_count: 5
93
+ other_count: 1
94
+ }
95
+ pusher_stats {
96
+ uploaded_bytes: 45931
97
+ total_bytes: 45931
98
+ }
99
+
100
+ 2022-01-14 21:32:19,269 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
101
+ wandb_count: 5
102
+ other_count: 1
103
+ }
104
+ pusher_stats {
105
+ uploaded_bytes: 45931
106
+ total_bytes: 45931
107
+ }
108
+
109
+ 2022-01-14 21:32:19,371 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
110
+ wandb_count: 5
111
+ other_count: 1
112
+ }
113
+ pusher_stats {
114
+ uploaded_bytes: 45931
115
+ total_bytes: 45931
116
+ }
117
+
118
+ 2022-01-14 21:32:19,473 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
119
+ wandb_count: 5
120
+ other_count: 1
121
+ }
122
+ pusher_stats {
123
+ uploaded_bytes: 45931
124
+ total_bytes: 45931
125
+ }
126
+
127
+ 2022-01-14 21:32:19,917 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
128
+ wandb_count: 5
129
+ other_count: 1
130
+ }
131
+ pusher_stats {
132
+ uploaded_bytes: 45931
133
+ total_bytes: 45931
134
+ }
135
+
136
+ 2022-01-14 21:32:20,277 INFO MainThread:5000 [wandb_run.py:_wait_for_finish():1912] got exit ret: done: true
137
+ exit_result {
138
+ }
139
+ file_counts {
140
+ wandb_count: 5
141
+ other_count: 1
142
+ }
143
+ pusher_stats {
144
+ uploaded_bytes: 45931
145
+ total_bytes: 45931
146
+ }
147
+ local_info {
148
+ }
149
+
150
+ 2022-01-14 21:32:23,445 INFO MainThread:5000 [wandb_run.py:_append_files():2180] logging synced files
wandb/run-20220114_212855-32qdb4k5/run-32qdb4k5.wandb ADDED
Binary file (7.7 kB)
wandb/run-20220114_221533-24dma583/files/code/run_mlm_flax.py ADDED
@@ -0,0 +1,815 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=fill-mask
22
+ """
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import asdict, dataclass, field
30
+ from enum import Enum
31
+ from itertools import chain
32
+
33
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
34
+ from pathlib import Path
35
+ from typing import Dict, List, Optional, Tuple
36
+
37
+ import numpy as np
38
+ from datasets import load_dataset
39
+ from tqdm import tqdm
40
+
41
+ import flax
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import optax
45
+ from flax import jax_utils, traverse_util
46
+ from flax.training import train_state
47
+ from flax.training.common_utils import get_metrics, onehot, shard
48
+ from huggingface_hub import Repository
49
+ from transformers import (
50
+ CONFIG_MAPPING,
51
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
52
+ AutoConfig,
53
+ AutoTokenizer,
54
+ FlaxAutoModelForMaskedLM,
55
+ HfArgumentParser,
56
+ PreTrainedTokenizerBase,
57
+ TensorType,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ from transformers.file_utils import get_full_repo_name
62
+
63
+
64
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
+
67
+
68
+ @dataclass
69
+ class TrainingArguments:
70
+ output_dir: str = field(
71
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
72
+ )
73
+ overwrite_output_dir: bool = field(
74
+ default=False,
75
+ metadata={
76
+ "help": (
77
+ "Overwrite the content of the output directory. "
78
+ "Use this to continue training if output_dir points to a checkpoint directory."
79
+ )
80
+ },
81
+ )
82
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
83
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
84
+ per_device_train_batch_size: int = field(
85
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
86
+ )
87
+ per_device_eval_batch_size: int = field(
88
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
89
+ )
90
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
91
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
92
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
93
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
94
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
95
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
96
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
97
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
98
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
99
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
100
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
101
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
102
+ push_to_hub: bool = field(
103
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
104
+ )
105
+ hub_model_id: str = field(
106
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
107
+ )
108
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
109
+
110
+ def __post_init__(self):
111
+ if self.output_dir is not None:
112
+ self.output_dir = os.path.expanduser(self.output_dir)
113
+
114
+ def to_dict(self):
115
+ """
116
+ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
117
+ the token values by removing their value.
118
+ """
119
+ d = asdict(self)
120
+ for k, v in d.items():
121
+ if isinstance(v, Enum):
122
+ d[k] = v.value
123
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
124
+ d[k] = [x.value for x in v]
125
+ if k.endswith("_token"):
126
+ d[k] = f"<{k.upper()}>"
127
+ return d
128
+
129
+
130
+ @dataclass
131
+ class ModelArguments:
132
+ """
133
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
134
+ """
135
+
136
+ model_name_or_path: Optional[str] = field(
137
+ default=None,
138
+ metadata={
139
+ "help": "The model checkpoint for weights initialization."
140
+ "Don't set if you want to train a model from scratch."
141
+ },
142
+ )
143
+ model_type: Optional[str] = field(
144
+ default=None,
145
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
146
+ )
147
+ config_name: Optional[str] = field(
148
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
149
+ )
150
+ tokenizer_name: Optional[str] = field(
151
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
152
+ )
153
+ cache_dir: Optional[str] = field(
154
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
155
+ )
156
+ use_fast_tokenizer: bool = field(
157
+ default=True,
158
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
159
+ )
160
+ dtype: Optional[str] = field(
161
+ default="float32",
162
+ metadata={
163
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
164
+ },
165
+ )
166
+
167
+
168
+ @dataclass
169
+ class DataTrainingArguments:
170
+ """
171
+ Arguments pertaining to what data we are going to input our model for training and eval.
172
+ """
173
+
174
+ dataset_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
176
+ )
177
+ dataset_config_name: Optional[str] = field(
178
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
179
+ )
180
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
181
+ validation_file: Optional[str] = field(
182
+ default=None,
183
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
184
+ )
185
+ train_ref_file: Optional[str] = field(
186
+ default=None,
187
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
188
+ )
189
+ validation_ref_file: Optional[str] = field(
190
+ default=None,
191
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
192
+ )
193
+ overwrite_cache: bool = field(
194
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
195
+ )
196
+ validation_split_percentage: Optional[int] = field(
197
+ default=5,
198
+ metadata={
199
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
200
+ },
201
+ )
202
+ max_seq_length: Optional[int] = field(
203
+ default=None,
204
+ metadata={
205
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
206
+ "than this will be truncated. Default to the max input length of the model."
207
+ },
208
+ )
209
+ preprocessing_num_workers: Optional[int] = field(
210
+ default=None,
211
+ metadata={"help": "The number of processes to use for the preprocessing."},
212
+ )
213
+ mlm_probability: float = field(
214
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
215
+ )
216
+ pad_to_max_length: bool = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "Whether to pad all samples to `max_seq_length`. "
220
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
221
+ },
222
+ )
223
+ line_by_line: bool = field(
224
+ default=False,
225
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
226
+ )
227
+
228
+ def __post_init__(self):
229
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
230
+ raise ValueError("Need either a dataset name or a training/validation file.")
231
+ else:
232
+ if self.train_file is not None:
233
+ extension = self.train_file.split(".")[-1]
234
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
235
+ if self.validation_file is not None:
236
+ extension = self.validation_file.split(".")[-1]
237
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
238
+
239
+
240
+ @flax.struct.dataclass
241
+ class FlaxDataCollatorForLanguageModeling:
242
+ """
243
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
244
+ are not all of the same length.
245
+
246
+ Args:
247
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
248
+ The tokenizer used for encoding the data.
249
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
250
+ The probability with which to (randomly) mask tokens in the input.
251
+
252
+ .. note::
253
+
254
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
255
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
256
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
257
+ argument :obj:`return_special_tokens_mask=True`.
258
+ """
259
+
260
+ tokenizer: PreTrainedTokenizerBase
261
+ mlm_probability: float = 0.15
262
+
263
+ def __post_init__(self):
264
+ if self.tokenizer.mask_token is None:
265
+ raise ValueError(
266
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
267
+ "You should pass `mlm=False` to train on causal language modeling instead."
268
+ )
269
+
270
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
271
+ # Handle dict or lists with proper padding and conversion to tensor.
272
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
273
+
274
+ # If special token mask has been preprocessed, pop it from the dict.
275
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
276
+
277
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
278
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
279
+ )
280
+ return batch
281
+
282
+ def mask_tokens(
283
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
284
+ ) -> Tuple[np.ndarray, np.ndarray]:
285
+ """
286
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
287
+ """
288
+ labels = inputs.copy()
289
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
290
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
291
+ special_tokens_mask = special_tokens_mask.astype("bool")
292
+
293
+ probability_matrix[special_tokens_mask] = 0.0
294
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
295
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
296
+
297
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
298
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
299
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
300
+
301
+ # 10% of the time, we replace masked input tokens with random word
302
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
303
+ indices_random &= masked_indices & ~indices_replaced
304
+
305
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
306
+ inputs[indices_random] = random_words[indices_random]
307
+
308
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
309
+ return inputs, labels
310
+
311
+
312
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
313
+ num_samples = len(samples_idx)
314
+ samples_to_remove = num_samples % batch_size
315
+
316
+ if samples_to_remove != 0:
317
+ samples_idx = samples_idx[:-samples_to_remove]
318
+ sections_split = num_samples // batch_size
319
+ batch_idx = np.split(samples_idx, sections_split)
320
+ return batch_idx
321
+
322
+
323
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
324
+ summary_writer.scalar("train_time", train_time, step)
325
+
326
+ train_metrics = get_metrics(train_metrics)
327
+ for key, vals in train_metrics.items():
328
+ tag = f"train_{key}"
329
+ for i, val in enumerate(vals):
330
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
331
+
332
+
333
+ def write_eval_metric(summary_writer, eval_metrics, step):
334
+ for metric_name, value in eval_metrics.items():
335
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
336
+
337
+
338
+ def main():
339
+ # See all possible arguments in src/transformers/training_args.py
340
+ # or by passing the --help flag to this script.
341
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
342
+
343
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
344
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
345
+ # If we pass only one argument to the script and it's the path to a json file,
346
+ # let's parse it to get our arguments.
347
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
348
+ else:
349
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
350
+
351
+ if (
352
+ os.path.exists(training_args.output_dir)
353
+ and os.listdir(training_args.output_dir)
354
+ and training_args.do_train
355
+ and not training_args.overwrite_output_dir
356
+ ):
357
+ raise ValueError(
358
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
359
+ "Use --overwrite_output_dir to overcome."
360
+ )
361
+
362
+ # Setup logging
363
+ logging.basicConfig(
364
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
365
+ level=logging.INFO,
366
+ datefmt="[%X]",
367
+ )
368
+
369
+ # Log on each process the small summary:
370
+ logger = logging.getLogger(__name__)
371
+
372
+ # Set the verbosity to info of the Transformers logger (on main process only):
373
+ logger.info(f"Training/evaluation parameters {training_args}")
374
+
375
+ # Set seed before initializing model.
376
+ set_seed(training_args.seed)
377
+
378
+ # Handle the repository creation
379
+ if training_args.push_to_hub:
380
+ if training_args.hub_model_id is None:
381
+ repo_name = get_full_repo_name(
382
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
383
+ )
384
+ else:
385
+ repo_name = training_args.hub_model_id
386
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
387
+
388
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
389
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
390
+ # (the dataset will be downloaded automatically from the datasets Hub).
391
+ #
392
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
393
+ # 'text' is found. You can easily tweak this behavior (see below).
394
+ #
395
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
396
+ # download the dataset.
397
+ if data_args.dataset_name is not None:
398
+ # Downloading and loading a dataset from the hub.
399
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
400
+
401
+ if "validation" not in datasets.keys():
402
+ datasets["validation"] = load_dataset(
403
+ data_args.dataset_name,
404
+ data_args.dataset_config_name,
405
+ split=f"train[:{data_args.validation_split_percentage}%]",
406
+ cache_dir=model_args.cache_dir,
407
+ )
408
+ datasets["train"] = load_dataset(
409
+ data_args.dataset_name,
410
+ data_args.dataset_config_name,
411
+ split=f"train[{data_args.validation_split_percentage}%:]",
412
+ cache_dir=model_args.cache_dir,
413
+ )
414
+ else:
415
+ data_files = {}
416
+ if data_args.train_file is not None:
417
+ data_files["train"] = data_args.train_file
418
+ if data_args.validation_file is not None:
419
+ data_files["validation"] = data_args.validation_file
420
+ extension = data_args.train_file.split(".")[-1]
421
+ if extension == "txt":
422
+ extension = "text"
423
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
424
+
425
+ if "validation" not in datasets.keys():
426
+ datasets["validation"] = load_dataset(
427
+ extension,
428
+ data_files=data_files,
429
+ split=f"train[:{data_args.validation_split_percentage}%]",
430
+ cache_dir=model_args.cache_dir,
431
+ )
432
+ datasets["train"] = load_dataset(
433
+ extension,
434
+ data_files=data_files,
435
+ split=f"train[{data_args.validation_split_percentage}%:]",
436
+ cache_dir=model_args.cache_dir,
437
+ )
438
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
439
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
440
+
441
+ # Load pretrained model and tokenizer
442
+
443
+ # Distributed training:
444
+ # The .from_pretrained methods guarantee that only one local process can concurrently
445
+ # download model & vocab.
446
+ if model_args.config_name:
447
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
448
+ elif model_args.model_name_or_path:
449
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
450
+ else:
451
+ config = CONFIG_MAPPING[model_args.model_type]()
452
+ logger.warning("You are instantiating a new config instance from scratch.")
453
+
454
+ if model_args.tokenizer_name:
455
+ tokenizer = AutoTokenizer.from_pretrained(
456
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
457
+ )
458
+ elif model_args.model_name_or_path:
459
+ tokenizer = AutoTokenizer.from_pretrained(
460
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
461
+ )
462
+ else:
463
+ raise ValueError(
464
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
465
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
466
+ )
467
+
468
+ # Preprocessing the datasets.
469
+ # First we tokenize all the texts.
470
+ if training_args.do_train:
471
+ column_names = datasets["train"].column_names
472
+ else:
473
+ column_names = datasets["validation"].column_names
474
+ text_column_name = "text" if "text" in column_names else column_names[0]
475
+
476
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
477
+
478
+ if data_args.line_by_line:
479
+ # When using line_by_line, we just tokenize each nonempty line.
480
+ padding = "max_length" if data_args.pad_to_max_length else False
481
+
482
+ def tokenize_function(examples):
483
+ # Remove empty lines
484
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
485
+ return tokenizer(
486
+ examples,
487
+ return_special_tokens_mask=True,
488
+ padding=padding,
489
+ truncation=True,
490
+ max_length=max_seq_length,
491
+ )
492
+
493
+ tokenized_datasets = datasets.map(
494
+ tokenize_function,
495
+ input_columns=[text_column_name],
496
+ batched=True,
497
+ num_proc=data_args.preprocessing_num_workers,
498
+ remove_columns=column_names,
499
+ load_from_cache_file=not data_args.overwrite_cache,
500
+ )
501
+
502
+ else:
503
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
504
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
505
+ # efficient when it receives the `special_tokens_mask`.
506
+ def tokenize_function(examples):
507
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
508
+
509
+ tokenized_datasets = datasets.map(
510
+ tokenize_function,
511
+ batched=True,
512
+ num_proc=data_args.preprocessing_num_workers,
513
+ remove_columns=column_names,
514
+ load_from_cache_file=not data_args.overwrite_cache,
515
+ )
516
+
517
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
518
+ # max_seq_length.
519
+ def group_texts(examples):
520
+ # Concatenate all texts.
521
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
522
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
523
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
524
+ # customize this part to your needs.
525
+ if total_length >= max_seq_length:
526
+ total_length = (total_length // max_seq_length) * max_seq_length
527
+ # Split by chunks of max_len.
528
+ result = {
529
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
530
+ for k, t in concatenated_examples.items()
531
+ }
532
+ return result
533
+
534
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
535
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
536
+ # might be slower to preprocess.
537
+ #
538
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
539
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
540
+ tokenized_datasets = tokenized_datasets.map(
541
+ group_texts,
542
+ batched=True,
543
+ num_proc=data_args.preprocessing_num_workers,
544
+ load_from_cache_file=not data_args.overwrite_cache,
545
+ )
546
+
547
+ # Enable tensorboard only on the master node
548
+ has_tensorboard = is_tensorboard_available()
549
+ if has_tensorboard and jax.process_index() == 0:
550
+ try:
551
+ # Enable Weight&Biases
552
+ import wandb
553
+ wandb.init(
554
+ entity='versae',
555
+ project='roberta-base-ncc',
556
+ sync_tensorboard=False,
557
+ )
558
+ wandb.config.update(training_args)
559
+ wandb.config.update(model_args)
560
+ wandb.config.update(data_args)
561
+
562
+ from flax.metrics.tensorboard import SummaryWriter
563
+
564
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
565
+ except ImportError as ie:
566
+ has_tensorboard = False
567
+ logger.warning(
568
+ f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
569
+ )
570
+ else:
571
+ logger.warning(
572
+ "Unable to display metrics through TensorBoard because the package is not installed: "
573
+ "Please run pip install tensorboard to enable."
574
+ )
575
+
576
+ # Data collator
577
+ # This one will take care of randomly masking the tokens.
578
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
579
+
580
+ # Initialize our training
581
+ rng = jax.random.PRNGKey(training_args.seed)
582
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
583
+
584
+ if model_args.model_name_or_path:
585
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
586
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
587
+ )
588
+ else:
589
+ model = FlaxAutoModelForMaskedLM.from_config(
590
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
591
+ )
592
+
593
+ # Store some constant
594
+ num_epochs = int(training_args.num_train_epochs)
595
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
596
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
597
+
598
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
599
+
600
+ # Create learning rate schedule
601
+ warmup_fn = optax.linear_schedule(
602
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
603
+ )
604
+ decay_fn = optax.linear_schedule(
605
+ init_value=training_args.learning_rate,
606
+ end_value=0,
607
+ transition_steps=num_train_steps - training_args.warmup_steps,
608
+ )
609
+ linear_decay_lr_schedule_fn = optax.join_schedules(
610
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
611
+ )
612
+
613
+ # We use Optax's "masking" functionality to not apply weight decay
614
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
615
+ # mask boolean with the same structure as the parameters.
616
+ # The mask is True for parameters that should be decayed.
617
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
618
+ # For other models, one should correct the layer norm parameter naming
619
+ # accordingly.
620
+ def decay_mask_fn(params):
621
+ flat_params = traverse_util.flatten_dict(params)
622
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
623
+ return traverse_util.unflatten_dict(flat_mask)
624
+
625
+ # create adam optimizer
626
+ if training_args.adafactor:
627
+ # We use the default parameters here to initialize adafactor,
628
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
629
+ optimizer = optax.adafactor(
630
+ learning_rate=linear_decay_lr_schedule_fn,
631
+ )
632
+ else:
633
+ optimizer = optax.adamw(
634
+ learning_rate=linear_decay_lr_schedule_fn,
635
+ b1=training_args.adam_beta1,
636
+ b2=training_args.adam_beta2,
637
+ eps=training_args.adam_epsilon,
638
+ weight_decay=training_args.weight_decay,
639
+ mask=decay_mask_fn,
640
+ )
641
+
642
+ # Setup train state
643
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
644
+
645
+ # Define gradient update step fn
646
+ def train_step(state, batch, dropout_rng):
647
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
648
+
649
+ def loss_fn(params):
650
+ labels = batch.pop("labels")
651
+
652
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
653
+
654
+ # compute loss, ignore padded input tokens
655
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
656
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
657
+
658
+ # take average
659
+ loss = loss.sum() / label_mask.sum()
660
+
661
+ return loss
662
+
663
+ grad_fn = jax.value_and_grad(loss_fn)
664
+ loss, grad = grad_fn(state.params)
665
+ grad = jax.lax.pmean(grad, "batch")
666
+ new_state = state.apply_gradients(grads=grad)
667
+
668
+ metrics = jax.lax.pmean(
669
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
670
+ )
671
+
672
+ return new_state, metrics, new_dropout_rng
673
+
674
+ # Create parallel version of the train step
675
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
676
+
677
+ # Define eval fn
678
+ def eval_step(params, batch):
679
+ labels = batch.pop("labels")
680
+
681
+ logits = model(**batch, params=params, train=False)[0]
682
+
683
+ # compute loss, ignore padded input tokens
684
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
685
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
686
+
687
+ # compute accuracy
688
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
689
+
690
+ # summarize metrics
691
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
692
+ metrics = jax.lax.psum(metrics, axis_name="batch")
693
+
694
+ return metrics
695
+
696
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
697
+
698
+ # Replicate the train state on each device
699
+ state = jax_utils.replicate(state)
700
+
701
+ train_time = 0
702
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
703
+ for epoch in epochs:
704
+ # ======================== Training ================================
705
+ train_start = time.time()
706
+ train_metrics = []
707
+
708
+ # Create sampling rng
709
+ rng, input_rng = jax.random.split(rng)
710
+
711
+ # Generate an epoch by shuffling sampling indices from the train dataset
712
+ num_train_samples = len(tokenized_datasets["train"])
713
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
714
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
715
+
716
+ # Gather the indexes for creating the batch and do a training step
717
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
718
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
719
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
720
+
721
+ # Model forward
722
+ model_inputs = shard(model_inputs.data)
723
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
724
+ train_metrics.append(train_metric)
725
+
726
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
727
+
728
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
729
+ # Save metrics
730
+ train_metric = jax_utils.unreplicate(train_metric)
731
+ train_time += time.time() - train_start
732
+ if has_tensorboard and jax.process_index() == 0:
733
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
734
+
735
+ epochs.write(
736
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
737
+ )
738
+
739
+ train_metrics = []
740
+
741
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
742
+ # ======================== Evaluating ==============================
743
+ num_eval_samples = len(tokenized_datasets["validation"])
744
+ eval_samples_idx = jnp.arange(num_eval_samples)
745
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
746
+
747
+ eval_metrics = []
748
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
749
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
750
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
751
+
752
+ # Model forward
753
+ model_inputs = shard(model_inputs.data)
754
+ metrics = p_eval_step(state.params, model_inputs)
755
+ eval_metrics.append(metrics)
756
+
757
+ # normalize eval metrics
758
+ eval_metrics = get_metrics(eval_metrics)
759
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
760
+ eval_normalizer = eval_metrics.pop("normalizer")
761
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
762
+
763
+ # Update progress bar
764
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
765
+
766
+ # Save metrics
767
+ if has_tensorboard and jax.process_index() == 0:
768
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
769
+
770
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
771
+ # save checkpoint after each epoch and push checkpoint to the hub
772
+ if jax.process_index() == 0:
773
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
774
+ model.save_pretrained(training_args.output_dir, params=params)
775
+ tokenizer.save_pretrained(training_args.output_dir)
776
+ if training_args.push_to_hub:
777
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
778
+
779
+ # Eval after training
780
+ if training_args.do_eval:
781
+ num_eval_samples = len(tokenized_datasets["validation"])
782
+ eval_samples_idx = jnp.arange(num_eval_samples)
783
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
784
+
785
+ eval_metrics = []
786
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
787
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
788
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
789
+
790
+ # Model forward
791
+ model_inputs = shard(model_inputs.data)
792
+ metrics = p_eval_step(state.params, model_inputs)
793
+ eval_metrics.append(metrics)
794
+
795
+ # normalize eval metrics
796
+ eval_metrics = get_metrics(eval_metrics)
797
+ eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
798
+ eval_normalizer = eval_metrics.pop("normalizer")
799
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
800
+
801
+ try:
802
+ perplexity = math.exp(eval_metrics["loss"])
803
+ except OverflowError:
804
+ perplexity = float("inf")
805
+ eval_metrics["perplexity"] = perplexity
806
+
807
+ if jax.process_index() == 0:
808
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
809
+ path = os.path.join(training_args.output_dir, "eval_results.json")
810
+ with open(path, "w") as f:
811
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
812
+
813
+
814
+ if __name__ == "__main__":
815
+ main()
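Note on the masking logic in the script above: FlaxDataCollatorForLanguageModeling.mask_tokens implements the standard BERT-style 80/10/10 scheme: of the positions selected for masking, 80% become the mask token, 10% a random token, and 10% are left unchanged, while unmasked positions get the label -100 so the loss ignores them. The snippet below is a minimal standalone sketch of that NumPy logic on a toy batch; the token ids, mask_token_id=4 and vocab_size=10 are illustrative assumptions, not values from this run.

import numpy as np

def mask_tokens(inputs, special_tokens_mask, mlm_probability=0.15,
                mask_token_id=4, vocab_size=10):
    labels = inputs.copy()
    # Select positions to mask, never touching special tokens.
    probability_matrix = np.full(labels.shape, mlm_probability)
    probability_matrix[special_tokens_mask.astype(bool)] = 0.0
    masked_indices = np.random.binomial(1, probability_matrix).astype(bool)
    labels[~masked_indices] = -100  # loss is only computed on masked positions

    # 80% of the masked positions -> mask token
    indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype(bool) & masked_indices
    inputs[indices_replaced] = mask_token_id

    # Half of the remainder (10% overall) -> random token; the rest stay unchanged
    indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype(bool)
    indices_random &= masked_indices & ~indices_replaced
    inputs[indices_random] = np.random.randint(vocab_size, size=labels.shape)[indices_random]
    return inputs, labels

batch = np.array([[0, 5, 6, 7, 8, 9, 2]])    # 0 = <s>, 2 = </s> (toy ids)
special = np.array([[1, 0, 0, 0, 0, 0, 1]])
print(mask_tokens(batch.copy(), special, mlm_probability=0.5))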
wandb/run-20220114_221533-24dma583/files/config.yaml ADDED
@@ -0,0 +1,152 @@
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_mlm_flax.py
8
+ framework: huggingface
9
+ huggingface_version: 4.16.0.dev0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1642198533
14
+ t:
15
+ 1:
16
+ - 2
17
+ - 3
18
+ - 11
19
+ - 12
20
+ 2:
21
+ - 2
22
+ - 3
23
+ - 11
24
+ - 12
25
+ 4: 3.8.10
26
+ 5: 0.12.9
27
+ 6: 4.16.0.dev0
28
+ 8:
29
+ - 5
30
+ adafactor:
31
+ desc: null
32
+ value: false
33
+ adam_beta1:
34
+ desc: null
35
+ value: 0.9
36
+ adam_beta2:
37
+ desc: null
38
+ value: 0.98
39
+ adam_epsilon:
40
+ desc: null
41
+ value: 1.0e-06
42
+ cache_dir:
43
+ desc: null
44
+ value: null
45
+ config_name:
46
+ desc: null
47
+ value: roberta-base
48
+ dataset_config_name:
49
+ desc: null
50
+ value: null
51
+ dataset_name:
52
+ desc: null
53
+ value: NbAiLab/NCC
54
+ do_eval:
55
+ desc: null
56
+ value: true
57
+ do_train:
58
+ desc: null
59
+ value: true
60
+ dtype:
61
+ desc: null
62
+ value: bfloat16
63
+ eval_steps:
64
+ desc: null
65
+ value: 1000
66
+ hub_model_id:
67
+ desc: null
68
+ value: null
69
+ hub_token:
70
+ desc: null
71
+ value: null
72
+ learning_rate:
73
+ desc: null
74
+ value: 0.0006
75
+ line_by_line:
76
+ desc: null
77
+ value: false
78
+ logging_steps:
79
+ desc: null
80
+ value: 1000
81
+ max_seq_length:
82
+ desc: null
83
+ value: 128
84
+ mlm_probability:
85
+ desc: null
86
+ value: 0.15
87
+ model_name_or_path:
88
+ desc: null
89
+ value: null
90
+ model_type:
91
+ desc: null
92
+ value: roberta
93
+ num_train_epochs:
94
+ desc: null
95
+ value: 3.0
96
+ output_dir:
97
+ desc: null
98
+ value: ./
99
+ overwrite_cache:
100
+ desc: null
101
+ value: false
102
+ overwrite_output_dir:
103
+ desc: null
104
+ value: true
105
+ pad_to_max_length:
106
+ desc: null
107
+ value: true
108
+ per_device_eval_batch_size:
109
+ desc: null
110
+ value: 250
111
+ per_device_train_batch_size:
112
+ desc: null
113
+ value: 250
114
+ preprocessing_num_workers:
115
+ desc: null
116
+ value: null
117
+ push_to_hub:
118
+ desc: null
119
+ value: true
120
+ save_steps:
121
+ desc: null
122
+ value: 1000
123
+ seed:
124
+ desc: null
125
+ value: 42
126
+ tokenizer_name:
127
+ desc: null
128
+ value: NbAiLab/nb-roberta-base
129
+ train_file:
130
+ desc: null
131
+ value: null
132
+ train_ref_file:
133
+ desc: null
134
+ value: null
135
+ use_fast_tokenizer:
136
+ desc: null
137
+ value: true
138
+ validation_file:
139
+ desc: null
140
+ value: null
141
+ validation_ref_file:
142
+ desc: null
143
+ value: null
144
+ validation_split_percentage:
145
+ desc: null
146
+ value: 5
147
+ warmup_steps:
148
+ desc: null
149
+ value: 10000
150
+ weight_decay:
151
+ desc: null
152
+ value: 0.01
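With learning_rate=6e-4 and warmup_steps=10000 from this config, run_mlm_flax.py builds its schedule as a linear warmup joined to a linear decay via optax.join_schedules. Below is a minimal sketch of that schedule in isolation; the total step count of 119,757 is inferred from the 39,919 training batches per epoch reported in the output log below times 3 epochs, so treat it as an approximation rather than a value taken from the logs.

import optax

learning_rate = 6e-4
warmup_steps = 10_000
num_train_steps = 119_757  # approx: 39_919 batches/epoch * 3 epochs for this run

warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate,
                                  transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(init_value=learning_rate, end_value=0.0,
                                 transition_steps=num_train_steps - warmup_steps)
schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

for step in (0, 5_000, 10_000, 60_000, num_train_steps):
    # Ramps to 6e-4 at step 10k, then decays linearly to 0.
    print(step, float(schedule(step)))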
wandb/run-20220114_221533-24dma583/files/diff.patch ADDED
File without changes
wandb/run-20220114_221533-24dma583/files/output.log ADDED
@@ -0,0 +1,43 @@
1
+ 2022-01-14 22:15:40.254500: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2
+ 2022-01-14 22:15:40.254546: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
3
+ Epoch ... (1/3): 0%| | 0/3 [00:00<?, ?it/s]
4
+ Training...: 0%| | 0/39919 [02:25<?, ?it/s]
5
+ Epoch ... (1/3): 0%| | 0/3 [03:13<?, ?it/s]
6
+ Traceback (most recent call last):
7
+ File "run_mlm_flax.py", line 815, in <module>
8
+ main()
9
+ File "run_mlm_flax.py", line 723, in main
10
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
11
+ File "/data/flax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
12
+ return fun(*args, **kwargs)
13
+ File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 2058, in cache_miss
14
+ out_tree, out_flat = f_pmapped_(*args, **kwargs)
15
+ File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 1934, in f_pmapped
16
+ out = pxla.xla_pmap(
17
+ File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 1727, in bind
18
+ return call_bind(self, fun, *args, **params)
19
+ File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 1652, in call_bind
20
+ outs = primitive.process(top_trace, fun, tracers, params)
21
+ File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 1730, in process
22
+ return trace.process_map(self, fun, tracers, params)
23
+ File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 633, in process_call
24
+ return primitive.impl(f, *tracers, **params)
25
+ File "/data/flax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 778, in xla_pmap_impl
26
+ return compiled_fun(*args)
27
+ File "/data/flax/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
28
+ return func(*args, **kwargs)
29
+ File "/data/flax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1502, in execute_replicated
30
+ out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
31
+ jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 12.83G at the bottom of memory. That was not possible. There are 13.18G free, 0B reserved, and 12.71G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
32
+ The stack trace below excludes JAX-internal frames.
33
+ The preceding is the original exception that occurred, unmodified.
34
+ --------------------
35
+ The above exception was the direct cause of the following exception:
36
+ Traceback (most recent call last):
37
+ File "run_mlm_flax.py", line 815, in <module>
38
+ main()
39
+ File "run_mlm_flax.py", line 723, in main
40
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
41
+ File "/data/flax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1502, in execute_replicated
42
+ out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
43
+ RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 12.83G at the bottom of memory. That was not possible. There are 13.18G free, 0B reserved, and 12.71G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
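The RESOURCE_EXHAUSTED error above is an out-of-memory failure on the TPU: the run tries to reserve 12.83 GB per device while only about 12.71 GB is reservable. With per_device_train_batch_size=250 the effective batch is very large, since the script multiplies it by the local device count. A quick sanity check of that arithmetic follows; the device count of 8 assumes a standard TPU v3-8 host, which is not stated explicitly in these logs.

# Effective batch size as computed in run_mlm_flax.py (device count of 8 is an assumption).
per_device_train_batch_size = 250
device_count = 8              # jax.device_count() on a TPU v3-8
max_seq_length = 128

train_batch_size = per_device_train_batch_size * device_count
tokens_per_step = train_batch_size * max_seq_length
print(train_batch_size, tokens_per_step)  # 2000 sequences, 256000 tokens per optimizer step

Lowering --per_device_train_batch_size (or max_seq_length) is the usual way to bring the per-device reservation under the reservable memory reported by the runtime.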
wandb/run-20220114_221533-24dma583/files/requirements.txt ADDED
@@ -0,0 +1,122 @@
+ absl-py==1.0.0
+ aiohttp==3.8.1
+ aiosignal==1.2.0
+ astunparse==1.6.3
+ async-timeout==4.0.2
+ attrs==21.4.0
+ backcall==0.2.0
+ cachetools==4.2.4
+ certifi==2021.10.8
+ charset-normalizer==2.0.10
+ chex==0.1.0
+ click==8.0.3
+ clu==0.0.6
+ configparser==5.2.0
+ contextlib2==21.6.0
+ cycler==0.11.0
+ datasets==1.17.1.dev0
+ decorator==5.1.0
+ dill==0.3.4
+ dm-tree==0.1.6
+ docker-pycreds==0.4.0
+ filelock==3.4.2
+ flatbuffers==2.0
+ flax==0.3.6
+ fonttools==4.28.5
+ frozenlist==1.2.0
+ fsspec==2021.11.1
+ future==0.18.2
+ gast==0.4.0
+ gitdb==4.0.9
+ gitpython==3.1.26
+ google-auth-oauthlib==0.4.6
+ google-auth==2.3.3
+ google-pasta==0.2.0
+ googleapis-common-protos==1.54.0
+ grpcio==1.43.0
+ h5py==3.6.0
+ huggingface-hub==0.2.1
+ idna==3.3
+ importlib-metadata==4.10.0
+ importlib-resources==5.4.0
+ ipython==7.31.0
+ jax==0.2.26
+ jaxlib==0.1.75
+ jedi==0.18.1
+ joblib==1.1.0
+ keras-preprocessing==1.1.2
+ keras==2.7.0
+ kiwisolver==1.3.2
+ libclang==12.0.0
+ libtpu-nightly==0.1.dev20211208
+ markdown==3.3.6
+ matplotlib-inline==0.1.3
+ matplotlib==3.5.1
+ ml-collections==0.1.0
+ msgpack==1.0.3
+ multidict==5.2.0
+ multiprocess==0.70.12.2
+ numpy==1.22.0
+ oauthlib==3.1.1
+ opt-einsum==3.3.0
+ optax==0.1.0
+ packaging==21.3
+ pandas==1.3.5
+ parso==0.8.3
+ pathtools==0.1.2
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ pillow==9.0.0
+ pip==20.0.2
+ pkg-resources==0.0.0
+ promise==2.3
+ prompt-toolkit==3.0.24
+ protobuf==3.19.1
+ psutil==5.9.0
+ ptyprocess==0.7.0
+ pyarrow==6.0.1
+ pyasn1-modules==0.2.8
+ pyasn1==0.4.8
+ pygments==2.11.1
+ pyparsing==3.0.6
+ python-dateutil==2.8.2
+ pytz==2021.3
+ pyyaml==6.0
+ regex==2021.11.10
+ requests-oauthlib==1.3.0
+ requests==2.27.0
+ rsa==4.8
+ sacremoses==0.0.46
+ scipy==1.7.3
+ sentry-sdk==1.5.2
+ setuptools==44.0.0
+ shortuuid==1.0.8
+ six==1.16.0
+ smmap==5.0.0
+ subprocess32==3.5.4
+ tensorboard-data-server==0.6.1
+ tensorboard-plugin-wit==1.8.0
+ tensorboard==2.7.0
+ tensorflow-cpu==2.7.0
+ tensorflow-datasets==4.4.0
+ tensorflow-estimator==2.7.0
+ tensorflow-io-gcs-filesystem==0.23.1
+ tensorflow-metadata==1.5.0
+ tensorflow==2.7.0
+ termcolor==1.1.0
+ tokenizers==0.11.2
+ toolz==0.11.2
+ tqdm==4.62.3
+ traitlets==5.1.1
+ transformers==4.16.0.dev0
+ typing-extensions==3.10.0.2
+ urllib3==1.26.7
+ wandb==0.12.9
+ wcwidth==0.2.5
+ werkzeug==2.0.2
+ wheel==0.37.1
+ wrapt==1.13.3
+ xxhash==2.0.2
+ yarl==1.7.2
+ yaspin==2.1.0
+ zipp==3.7.0
wandb/run-20220114_221533-24dma583/files/wandb-metadata.json ADDED
@@ -0,0 +1,47 @@
+ {
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
+ "python": "3.8.10",
+ "heartbeatAt": "2022-01-14T22:15:37.284889",
+ "startedAt": "2022-01-14T22:15:33.798491",
+ "docker": null,
+ "cpu_count": 96,
+ "cuda": null,
+ "args": [
+ "--output_dir=./",
+ "--model_type=roberta",
+ "--config_name=roberta-base",
+ "--tokenizer_name=NbAiLab/nb-roberta-base",
+ "--dataset_name=NbAiLab/NCC",
+ "--max_seq_length=128",
+ "--weight_decay=0.01",
+ "--per_device_train_batch_size=250",
+ "--per_device_eval_batch_size=250",
+ "--pad_to_max_length",
+ "--learning_rate=6e-4",
+ "--warmup_steps=10000",
+ "--overwrite_output_dir",
+ "--num_train_epochs=3",
+ "--adam_beta1=0.9",
+ "--adam_beta2=0.98",
+ "--adam_epsilon=1e-6",
+ "--logging_steps=1000",
+ "--save_steps=1000",
+ "--eval_steps=1000",
+ "--do_train",
+ "--do_eval",
+ "--dtype=bfloat16",
+ "--push_to_hub"
+ ],
+ "state": "running",
+ "program": "run_mlm_flax.py",
+ "codePath": "run_mlm_flax.py",
+ "git": {
+ "remote": "https://huggingface.co/versae/roberta-base-ncc",
+ "commit": "502df078f73cf93ca9380fcac1c9b9c7598a445f"
+ },
+ "email": "versae@gmail.com",
+ "root": "/data/roberta-base-ncc",
+ "host": "t1v-n-eedfb410-w-0",
+ "username": "javierr",
+ "executable": "/data/flax/bin/python"
+ }
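The "args" list above is what run_mlm_flax.py (included later in this commit) feeds to HfArgumentParser, which splits the flags across its ModelArguments, DataTrainingArguments and TrainingArguments dataclasses. A minimal sketch of that parsing step follows; it uses cut-down stand-in dataclasses rather than the real ones (MiniModelArguments and MiniTrainingArguments are illustrative names only), and assumes the transformers version pinned in requirements.txt above.

# Simplified sketch of how CLI flags become dataclass fields via HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class MiniTrainingArguments:
    output_dir: str = field(metadata={"help": "Where checkpoints are written."})
    per_device_train_batch_size: int = 8
    learning_rate: float = 5e-5
    warmup_steps: int = 0

@dataclass
class MiniModelArguments:
    model_type: Optional[str] = None
    tokenizer_name: Optional[str] = None
    dtype: str = "float32"

parser = HfArgumentParser((MiniModelArguments, MiniTrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses(args=[
    "--output_dir=./",
    "--model_type=roberta",
    "--tokenizer_name=NbAiLab/nb-roberta-base",
    "--per_device_train_batch_size=250",
    "--learning_rate=6e-4",
    "--warmup_steps=10000",
    "--dtype=bfloat16",
])
print(training_args)  # e.g. MiniTrainingArguments(output_dir='./', per_device_train_batch_size=250, ...)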
wandb/run-20220114_221533-24dma583/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
+ {"_wandb": {"runtime": 208}}
wandb/run-20220114_221533-24dma583/logs/debug-internal.log ADDED
@@ -0,0 +1,187 @@
1
+ 2022-01-14 22:15:34,709 INFO MainThread:7834 [internal.py:wandb_internal():87] W&B internal server running at pid: 7834, started at: 2022-01-14 22:15:34.709583
2
+ 2022-01-14 22:15:34,711 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: check_version
3
+ 2022-01-14 22:15:34,712 INFO WriterThread:7834 [datastore.py:open_for_write():77] open: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/run-24dma583.wandb
4
+ 2022-01-14 22:15:34,712 DEBUG SenderThread:7834 [sender.py:send():234] send: header
5
+ 2022-01-14 22:15:34,712 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: check_version
6
+ 2022-01-14 22:15:34,785 DEBUG SenderThread:7834 [sender.py:send():234] send: run
7
+ 2022-01-14 22:15:34,980 INFO SenderThread:7834 [dir_watcher.py:__init__():169] watching files in: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files
8
+ 2022-01-14 22:15:34,980 INFO SenderThread:7834 [sender.py:_start_run_threads():804] run started: 24dma583 with start time 1642198533
9
+ 2022-01-14 22:15:34,980 DEBUG SenderThread:7834 [sender.py:send():234] send: summary
10
+ 2022-01-14 22:15:34,980 INFO SenderThread:7834 [sender.py:_save_file():939] saving file wandb-summary.json with policy end
11
+ 2022-01-14 22:15:34,981 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: run_start
12
+ 2022-01-14 22:15:35,985 INFO Thread-8 :7834 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/wandb-summary.json
13
+ 2022-01-14 22:15:37,284 DEBUG HandlerThread:7834 [meta.py:__init__():40] meta init
14
+ 2022-01-14 22:15:37,284 DEBUG HandlerThread:7834 [meta.py:__init__():54] meta init done
15
+ 2022-01-14 22:15:37,284 DEBUG HandlerThread:7834 [meta.py:probe():214] probe
16
+ 2022-01-14 22:15:37,286 DEBUG HandlerThread:7834 [meta.py:_setup_git():204] setup git
17
+ 2022-01-14 22:15:37,315 DEBUG HandlerThread:7834 [meta.py:_setup_git():211] setup git done
18
+ 2022-01-14 22:15:37,315 DEBUG HandlerThread:7834 [meta.py:_save_code():92] save code
19
+ 2022-01-14 22:15:37,326 DEBUG HandlerThread:7834 [meta.py:_save_code():113] save code done
20
+ 2022-01-14 22:15:37,326 DEBUG HandlerThread:7834 [meta.py:_save_patches():130] save patches
21
+ 2022-01-14 22:15:37,985 INFO Thread-8 :7834 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/code/run_mlm_flax.py
22
+ 2022-01-14 22:15:37,986 INFO Thread-8 :7834 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/code
23
+ 2022-01-14 22:15:39,986 INFO Thread-8 :7834 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/output.log
24
+ 2022-01-14 22:15:40,986 INFO Thread-8 :7834 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/diff.patch
25
+ 2022-01-14 22:15:42,987 INFO Thread-8 :7834 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/output.log
26
+ 2022-01-14 22:15:45,607 ERROR HandlerThread:7834 [meta.py:_save_patches():171] Error generating diff: Command '['git', 'diff', '--submodule=diff', 'HEAD']' timed out after 5 seconds
27
+ 2022-01-14 22:15:45,607 DEBUG HandlerThread:7834 [meta.py:_save_patches():172] save patches done
28
+ 2022-01-14 22:15:45,607 DEBUG HandlerThread:7834 [meta.py:_save_pip():58] save pip
29
+ 2022-01-14 22:15:45,607 DEBUG HandlerThread:7834 [meta.py:_save_pip():72] save pip done
30
+ 2022-01-14 22:15:45,608 DEBUG HandlerThread:7834 [meta.py:probe():252] probe done
31
+ 2022-01-14 22:15:45,643 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
32
+ 2022-01-14 22:15:45,643 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
33
+ 2022-01-14 22:15:45,786 DEBUG SenderThread:7834 [sender.py:send():234] send: config
34
+ 2022-01-14 22:15:45,787 DEBUG SenderThread:7834 [sender.py:send():234] send: config
35
+ 2022-01-14 22:15:45,787 DEBUG SenderThread:7834 [sender.py:send():234] send: config
36
+ 2022-01-14 22:15:45,787 DEBUG SenderThread:7834 [sender.py:send():234] send: files
37
+ 2022-01-14 22:15:45,787 INFO SenderThread:7834 [sender.py:_save_file():939] saving file wandb-metadata.json with policy now
38
+ 2022-01-14 22:15:45,789 INFO SenderThread:7834 [sender.py:_save_file():939] saving file code/run_mlm_flax.py with policy now
39
+ 2022-01-14 22:15:45,988 INFO Thread-8 :7834 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/wandb-metadata.json
40
+ 2022-01-14 22:15:45,989 INFO Thread-8 :7834 [dir_watcher.py:_on_file_created():217] file/dir created: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/requirements.txt
41
+ 2022-01-14 22:15:46,312 INFO Thread-12 :7834 [upload_job.py:push():137] Uploaded file /tmp/tmpxqv1l1fswandb/2juok80v-code/run_mlm_flax.py
42
+ 2022-01-14 22:15:46,330 INFO Thread-11 :7834 [upload_job.py:push():137] Uploaded file /tmp/tmpxqv1l1fswandb/xnc44171-wandb-metadata.json
43
+ 2022-01-14 22:15:50,990 INFO Thread-8 :7834 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/output.log
44
+ 2022-01-14 22:15:59,991 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
45
+ 2022-01-14 22:15:59,991 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
46
+ 2022-01-14 22:16:05,368 DEBUG SenderThread:7834 [sender.py:send():234] send: stats
47
+ 2022-01-14 22:16:05,997 INFO Thread-8 :7834 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/config.yaml
48
+ 2022-01-14 22:16:15,132 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
49
+ 2022-01-14 22:16:15,132 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
50
+ 2022-01-14 22:16:30,272 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
51
+ 2022-01-14 22:16:30,272 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
52
+ 2022-01-14 22:16:35,439 DEBUG SenderThread:7834 [sender.py:send():234] send: stats
53
+ 2022-01-14 22:16:39,009 INFO Thread-8 :7834 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/output.log
54
+ 2022-01-14 22:16:45,408 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
55
+ 2022-01-14 22:16:45,408 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
56
+ 2022-01-14 22:17:00,601 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
57
+ 2022-01-14 22:17:00,601 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
58
+ 2022-01-14 22:17:05,512 DEBUG SenderThread:7834 [sender.py:send():234] send: stats
59
+ 2022-01-14 22:17:15,756 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
60
+ 2022-01-14 22:17:15,756 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
61
+ 2022-01-14 22:17:30,970 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
62
+ 2022-01-14 22:17:30,971 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
63
+ 2022-01-14 22:17:35,586 DEBUG SenderThread:7834 [sender.py:send():234] send: stats
64
+ 2022-01-14 22:17:46,135 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
65
+ 2022-01-14 22:17:46,136 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
66
+ 2022-01-14 22:18:01,309 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
67
+ 2022-01-14 22:18:01,309 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
68
+ 2022-01-14 22:18:05,663 DEBUG SenderThread:7834 [sender.py:send():234] send: stats
69
+ 2022-01-14 22:18:16,458 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
70
+ 2022-01-14 22:18:16,458 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
71
+ 2022-01-14 22:18:31,596 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
72
+ 2022-01-14 22:18:31,597 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
73
+ 2022-01-14 22:18:35,742 DEBUG SenderThread:7834 [sender.py:send():234] send: stats
74
+ 2022-01-14 22:18:46,731 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: stop_status
75
+ 2022-01-14 22:18:46,732 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: stop_status
76
+ 2022-01-14 22:19:03,067 INFO Thread-8 :7834 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/output.log
77
+ 2022-01-14 22:19:03,953 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
78
+ 2022-01-14 22:19:03,953 DEBUG SenderThread:7834 [sender.py:send():234] send: telemetry
79
+ 2022-01-14 22:19:03,953 DEBUG SenderThread:7834 [sender.py:send():234] send: exit
80
+ 2022-01-14 22:19:03,953 INFO SenderThread:7834 [sender.py:send_exit():366] handling exit code: 1
81
+ 2022-01-14 22:19:03,954 INFO SenderThread:7834 [sender.py:send_exit():368] handling runtime: 208
82
+ 2022-01-14 22:19:03,954 INFO SenderThread:7834 [sender.py:_save_file():939] saving file wandb-summary.json with policy end
83
+ 2022-01-14 22:19:03,954 INFO SenderThread:7834 [sender.py:send_exit():374] send defer
84
+ 2022-01-14 22:19:03,954 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
85
+ 2022-01-14 22:19:03,955 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
86
+ 2022-01-14 22:19:03,955 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 0
87
+ 2022-01-14 22:19:03,955 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
88
+ 2022-01-14 22:19:03,955 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 0
89
+ 2022-01-14 22:19:03,955 INFO SenderThread:7834 [sender.py:transition_state():387] send defer: 1
90
+ 2022-01-14 22:19:03,955 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
91
+ 2022-01-14 22:19:03,955 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 1
92
+ 2022-01-14 22:19:04,011 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
93
+ 2022-01-14 22:19:04,011 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 1
94
+ 2022-01-14 22:19:04,011 INFO SenderThread:7834 [sender.py:transition_state():387] send defer: 2
95
+ 2022-01-14 22:19:04,011 DEBUG SenderThread:7834 [sender.py:send():234] send: stats
96
+ 2022-01-14 22:19:04,012 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
97
+ 2022-01-14 22:19:04,012 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 2
98
+ 2022-01-14 22:19:04,012 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
99
+ 2022-01-14 22:19:04,012 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 2
100
+ 2022-01-14 22:19:04,012 INFO SenderThread:7834 [sender.py:transition_state():387] send defer: 3
101
+ 2022-01-14 22:19:04,012 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
102
+ 2022-01-14 22:19:04,012 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 3
103
+ 2022-01-14 22:19:04,012 DEBUG SenderThread:7834 [sender.py:send():234] send: summary
104
+ 2022-01-14 22:19:04,013 INFO SenderThread:7834 [sender.py:_save_file():939] saving file wandb-summary.json with policy end
105
+ 2022-01-14 22:19:04,013 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
106
+ 2022-01-14 22:19:04,013 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 3
107
+ 2022-01-14 22:19:04,013 INFO SenderThread:7834 [sender.py:transition_state():387] send defer: 4
108
+ 2022-01-14 22:19:04,013 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
109
+ 2022-01-14 22:19:04,013 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 4
110
+ 2022-01-14 22:19:04,013 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
111
+ 2022-01-14 22:19:04,013 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 4
112
+ 2022-01-14 22:19:04,057 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
113
+ 2022-01-14 22:19:04,068 INFO Thread-8 :7834 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/output.log
114
+ 2022-01-14 22:19:04,068 INFO Thread-8 :7834 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/wandb-summary.json
115
+ 2022-01-14 22:19:04,198 INFO SenderThread:7834 [sender.py:transition_state():387] send defer: 5
116
+ 2022-01-14 22:19:04,198 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
117
+ 2022-01-14 22:19:04,198 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
118
+ 2022-01-14 22:19:04,198 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 5
119
+ 2022-01-14 22:19:04,199 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
120
+ 2022-01-14 22:19:04,199 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 5
121
+ 2022-01-14 22:19:04,199 INFO SenderThread:7834 [dir_watcher.py:finish():283] shutting down directory watcher
122
+ 2022-01-14 22:19:04,300 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
123
+ 2022-01-14 22:19:05,068 INFO Thread-8 :7834 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/config.yaml
124
+ 2022-01-14 22:19:05,069 INFO SenderThread:7834 [dir_watcher.py:finish():313] scan: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files
125
+ 2022-01-14 22:19:05,069 INFO SenderThread:7834 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/config.yaml config.yaml
126
+ 2022-01-14 22:19:05,069 INFO SenderThread:7834 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/diff.patch diff.patch
127
+ 2022-01-14 22:19:05,069 INFO SenderThread:7834 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/requirements.txt requirements.txt
128
+ 2022-01-14 22:19:05,069 INFO SenderThread:7834 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/output.log output.log
129
+ 2022-01-14 22:19:05,070 INFO SenderThread:7834 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/wandb-summary.json wandb-summary.json
130
+ 2022-01-14 22:19:05,070 INFO SenderThread:7834 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/wandb-metadata.json wandb-metadata.json
131
+ 2022-01-14 22:19:05,072 INFO SenderThread:7834 [dir_watcher.py:finish():327] scan save: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/code/run_mlm_flax.py code/run_mlm_flax.py
132
+ 2022-01-14 22:19:05,073 INFO SenderThread:7834 [sender.py:transition_state():387] send defer: 6
133
+ 2022-01-14 22:19:05,073 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
134
+ 2022-01-14 22:19:05,081 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
135
+ 2022-01-14 22:19:05,081 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 6
136
+ 2022-01-14 22:19:05,081 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
137
+ 2022-01-14 22:19:05,081 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 6
138
+ 2022-01-14 22:19:05,081 INFO SenderThread:7834 [file_pusher.py:finish():177] shutting down file pusher
139
+ 2022-01-14 22:19:05,183 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
140
+ 2022-01-14 22:19:05,183 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
141
+ 2022-01-14 22:19:05,285 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
142
+ 2022-01-14 22:19:05,285 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
143
+ 2022-01-14 22:19:05,387 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
144
+ 2022-01-14 22:19:05,387 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
145
+ 2022-01-14 22:19:05,488 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
146
+ 2022-01-14 22:19:05,489 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
147
+ 2022-01-14 22:19:05,539 INFO Thread-13 :7834 [upload_job.py:push():137] Uploaded file /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/config.yaml
148
+ 2022-01-14 22:19:05,556 INFO Thread-15 :7834 [upload_job.py:push():137] Uploaded file /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/output.log
149
+ 2022-01-14 22:19:05,561 INFO Thread-14 :7834 [upload_job.py:push():137] Uploaded file /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/requirements.txt
150
+ 2022-01-14 22:19:05,590 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
151
+ 2022-01-14 22:19:05,590 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
152
+ 2022-01-14 22:19:05,599 INFO Thread-16 :7834 [upload_job.py:push():137] Uploaded file /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/files/wandb-summary.json
153
+ 2022-01-14 22:19:05,692 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
154
+ 2022-01-14 22:19:05,692 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
155
+ 2022-01-14 22:19:05,794 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
156
+ 2022-01-14 22:19:05,794 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
157
+ 2022-01-14 22:19:05,799 INFO Thread-7 :7834 [sender.py:transition_state():387] send defer: 7
158
+ 2022-01-14 22:19:05,800 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
159
+ 2022-01-14 22:19:05,800 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 7
160
+ 2022-01-14 22:19:05,800 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
161
+ 2022-01-14 22:19:05,800 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 7
162
+ 2022-01-14 22:19:05,896 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
163
+ 2022-01-14 22:19:06,218 INFO SenderThread:7834 [sender.py:transition_state():387] send defer: 8
164
+ 2022-01-14 22:19:06,218 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
165
+ 2022-01-14 22:19:06,218 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
166
+ 2022-01-14 22:19:06,218 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 8
167
+ 2022-01-14 22:19:06,219 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
168
+ 2022-01-14 22:19:06,219 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 8
169
+ 2022-01-14 22:19:06,219 INFO SenderThread:7834 [sender.py:transition_state():387] send defer: 9
170
+ 2022-01-14 22:19:06,219 DEBUG SenderThread:7834 [sender.py:send():234] send: final
171
+ 2022-01-14 22:19:06,219 DEBUG SenderThread:7834 [sender.py:send():234] send: footer
172
+ 2022-01-14 22:19:06,219 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: defer
173
+ 2022-01-14 22:19:06,220 INFO HandlerThread:7834 [handler.py:handle_request_defer():147] handle defer: 9
174
+ 2022-01-14 22:19:06,220 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: defer
175
+ 2022-01-14 22:19:06,220 INFO SenderThread:7834 [sender.py:send_request_defer():383] handle sender defer: 9
176
+ 2022-01-14 22:19:06,320 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: poll_exit
177
+ 2022-01-14 22:19:06,320 DEBUG SenderThread:7834 [sender.py:send_request():248] send_request: poll_exit
178
+ 2022-01-14 22:19:06,320 INFO SenderThread:7834 [file_pusher.py:join():182] waiting for file pusher
179
+ 2022-01-14 22:19:06,598 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: get_summary
180
+ 2022-01-14 22:19:06,618 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: sampled_history
181
+ 2022-01-14 22:19:06,619 DEBUG HandlerThread:7834 [handler.py:handle_request():130] handle_request: shutdown
182
+ 2022-01-14 22:19:06,619 INFO HandlerThread:7834 [handler.py:finish():731] shutting down handler
183
+ 2022-01-14 22:19:07,220 INFO WriterThread:7834 [datastore.py:close():281] close: /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/run-24dma583.wandb
184
+ 2022-01-14 22:19:07,575 INFO SenderThread:7834 [sender.py:finish():1070] shutting down sender
185
+ 2022-01-14 22:19:07,575 INFO SenderThread:7834 [file_pusher.py:finish():177] shutting down file pusher
186
+ 2022-01-14 22:19:07,576 INFO SenderThread:7834 [file_pusher.py:join():182] waiting for file pusher
187
+ 2022-01-14 22:19:07,578 INFO MainThread:7834 [internal.py:handle_exit():77] Internal process exited
wandb/run-20220114_221533-24dma583/logs/debug.log ADDED
@@ -0,0 +1,141 @@
1
+ 2022-01-14 22:15:33,825 INFO MainThread:4503 [wandb_setup.py:_flush():71] setting env: {}
2
+ 2022-01-14 22:15:33,826 INFO MainThread:4503 [wandb_setup.py:_flush():71] setting login settings: {}
3
+ 2022-01-14 22:15:33,826 INFO MainThread:4503 [wandb_init.py:_log_setup():371] Logging user logs to /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/logs/debug.log
4
+ 2022-01-14 22:15:33,826 INFO MainThread:4503 [wandb_init.py:_log_setup():372] Logging internal logs to /data/roberta-base-ncc/wandb/run-20220114_221533-24dma583/logs/debug-internal.log
5
+ 2022-01-14 22:15:33,826 INFO MainThread:4503 [wandb_init.py:init():404] calling init triggers
6
+ 2022-01-14 22:15:33,826 INFO MainThread:4503 [wandb_init.py:init():409] wandb.init called with sweep_config: {}
7
+ config: {}
8
+ 2022-01-14 22:15:33,826 INFO MainThread:4503 [wandb_init.py:init():460] starting backend
9
+ 2022-01-14 22:15:33,826 INFO MainThread:4503 [backend.py:_multiprocessing_setup():99] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
10
+ 2022-01-14 22:15:33,871 INFO MainThread:4503 [backend.py:ensure_launched():216] starting backend process...
11
+ 2022-01-14 22:15:33,898 INFO MainThread:4503 [backend.py:ensure_launched():221] started backend process with pid: 7834
12
+ 2022-01-14 22:15:33,900 INFO MainThread:4503 [wandb_init.py:init():469] backend started and connected
13
+ 2022-01-14 22:15:33,911 INFO MainThread:4503 [wandb_init.py:init():533] updated telemetry
14
+ 2022-01-14 22:15:33,976 INFO MainThread:4503 [wandb_init.py:init():563] communicating current version
15
+ 2022-01-14 22:15:34,784 INFO MainThread:4503 [wandb_init.py:init():568] got version response
16
+ 2022-01-14 22:15:34,784 INFO MainThread:4503 [wandb_init.py:init():578] communicating run to backend with 30 second timeout
17
+ 2022-01-14 22:15:34,980 INFO MainThread:4503 [wandb_init.py:init():606] starting run threads in backend
18
+ 2022-01-14 22:15:39,985 INFO MainThread:4503 [wandb_run.py:_console_start():1810] atexit reg
19
+ 2022-01-14 22:15:39,985 INFO MainThread:4503 [wandb_run.py:_redirect():1684] redirect: SettingsConsole.REDIRECT
20
+ 2022-01-14 22:15:39,986 INFO MainThread:4503 [wandb_run.py:_redirect():1689] Redirecting console.
21
+ 2022-01-14 22:15:39,988 INFO MainThread:4503 [wandb_run.py:_redirect():1745] Redirects installed.
22
+ 2022-01-14 22:15:39,988 INFO MainThread:4503 [wandb_init.py:init():633] run started, returning control to user process
23
+ 2022-01-14 22:15:39,989 INFO MainThread:4503 [wandb_run.py:_config_callback():956] config_cb None None {'output_dir': './', 'overwrite_output_dir': True, 'do_train': True, 'do_eval': True, 'per_device_train_batch_size': 250, 'per_device_eval_batch_size': 250, 'learning_rate': 0.0006, 'weight_decay': 0.01, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-06, 'adafactor': False, 'num_train_epochs': 3.0, 'warmup_steps': 10000, 'logging_steps': 1000, 'save_steps': 1000, 'eval_steps': 1000, 'seed': 42, 'push_to_hub': True, 'hub_model_id': None, 'hub_token': None}
24
+ 2022-01-14 22:15:39,989 INFO MainThread:4503 [wandb_run.py:_config_callback():956] config_cb None None {'model_name_or_path': None, 'model_type': 'roberta', 'config_name': 'roberta-base', 'tokenizer_name': 'NbAiLab/nb-roberta-base', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'bfloat16'}
25
+ 2022-01-14 22:15:39,990 INFO MainThread:4503 [wandb_run.py:_config_callback():956] config_cb None None {'dataset_name': 'NbAiLab/NCC', 'dataset_config_name': None, 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 128, 'preprocessing_num_workers': None, 'mlm_probability': 0.15, 'pad_to_max_length': True, 'line_by_line': False}
26
+ 2022-01-14 22:19:01,641 INFO MainThread:4503 [wandb_run.py:_atexit_cleanup():1780] got exitcode: 1
27
+ 2022-01-14 22:19:01,645 INFO MainThread:4503 [wandb_run.py:_restore():1752] restore
28
+ 2022-01-14 22:19:03,955 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
29
+ wandb_count: 1
30
+ other_count: 1
31
+ }
32
+ pusher_stats {
33
+ uploaded_bytes: 37446
34
+ total_bytes: 37446
35
+ }
36
+
37
+ 2022-01-14 22:19:04,199 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
38
+ wandb_count: 1
39
+ other_count: 1
40
+ }
41
+ pusher_stats {
42
+ uploaded_bytes: 37446
43
+ total_bytes: 37446
44
+ }
45
+
46
+ 2022-01-14 22:19:05,082 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
47
+ wandb_count: 5
48
+ other_count: 1
49
+ }
50
+ pusher_stats {
51
+ uploaded_bytes: 37446
52
+ total_bytes: 45535
53
+ }
54
+
55
+ 2022-01-14 22:19:05,184 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
56
+ wandb_count: 5
57
+ other_count: 1
58
+ }
59
+ pusher_stats {
60
+ uploaded_bytes: 37446
61
+ total_bytes: 45535
62
+ }
63
+
64
+ 2022-01-14 22:19:05,286 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
65
+ wandb_count: 5
66
+ other_count: 1
67
+ }
68
+ pusher_stats {
69
+ uploaded_bytes: 45535
70
+ total_bytes: 45535
71
+ }
72
+
73
+ 2022-01-14 22:19:05,387 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
74
+ wandb_count: 5
75
+ other_count: 1
76
+ }
77
+ pusher_stats {
78
+ uploaded_bytes: 45535
79
+ total_bytes: 45535
80
+ }
81
+
82
+ 2022-01-14 22:19:05,489 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
83
+ wandb_count: 5
84
+ other_count: 1
85
+ }
86
+ pusher_stats {
87
+ uploaded_bytes: 45535
88
+ total_bytes: 45535
89
+ }
90
+
91
+ 2022-01-14 22:19:05,591 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
92
+ wandb_count: 5
93
+ other_count: 1
94
+ }
95
+ pusher_stats {
96
+ uploaded_bytes: 45535
97
+ total_bytes: 45535
98
+ }
99
+
100
+ 2022-01-14 22:19:05,693 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
101
+ wandb_count: 5
102
+ other_count: 1
103
+ }
104
+ pusher_stats {
105
+ uploaded_bytes: 45535
106
+ total_bytes: 45535
107
+ }
108
+
109
+ 2022-01-14 22:19:05,795 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
110
+ wandb_count: 5
111
+ other_count: 1
112
+ }
113
+ pusher_stats {
114
+ uploaded_bytes: 45535
115
+ total_bytes: 45535
116
+ }
117
+
118
+ 2022-01-14 22:19:06,219 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
119
+ wandb_count: 5
120
+ other_count: 1
121
+ }
122
+ pusher_stats {
123
+ uploaded_bytes: 45535
124
+ total_bytes: 45535
125
+ }
126
+
127
+ 2022-01-14 22:19:06,576 INFO MainThread:4503 [wandb_run.py:_wait_for_finish():1912] got exit ret: done: true
128
+ exit_result {
129
+ }
130
+ file_counts {
131
+ wandb_count: 5
132
+ other_count: 1
133
+ }
134
+ pusher_stats {
135
+ uploaded_bytes: 45535
136
+ total_bytes: 45535
137
+ }
138
+ local_info {
139
+ }
140
+
141
+ 2022-01-14 22:19:09,886 INFO MainThread:4503 [wandb_run.py:_append_files():2180] logging synced files
wandb/run-20220114_221533-24dma583/run-24dma583.wandb ADDED
Binary file (7.18 kB).
wandb/run-20220114_234119-1zya86oe/files/code/run_mlm_flax.py ADDED
@@ -0,0 +1,815 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=fill-mask
22
+ """
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import asdict, dataclass, field
30
+ from enum import Enum
31
+ from itertools import chain
32
+
33
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
34
+ from pathlib import Path
35
+ from typing import Dict, List, Optional, Tuple
36
+
37
+ import numpy as np
38
+ from datasets import load_dataset
39
+ from tqdm import tqdm
40
+
41
+ import flax
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import optax
45
+ from flax import jax_utils, traverse_util
46
+ from flax.training import train_state
47
+ from flax.training.common_utils import get_metrics, onehot, shard
48
+ from huggingface_hub import Repository
49
+ from transformers import (
50
+ CONFIG_MAPPING,
51
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
52
+ AutoConfig,
53
+ AutoTokenizer,
54
+ FlaxAutoModelForMaskedLM,
55
+ HfArgumentParser,
56
+ PreTrainedTokenizerBase,
57
+ TensorType,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ from transformers.file_utils import get_full_repo_name
62
+
63
+
64
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
+
67
+
68
+ @dataclass
69
+ class TrainingArguments:
70
+ output_dir: str = field(
71
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
72
+ )
73
+ overwrite_output_dir: bool = field(
74
+ default=False,
75
+ metadata={
76
+ "help": (
77
+ "Overwrite the content of the output directory. "
78
+ "Use this to continue training if output_dir points to a checkpoint directory."
79
+ )
80
+ },
81
+ )
82
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
83
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
84
+ per_device_train_batch_size: int = field(
85
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
86
+ )
87
+ per_device_eval_batch_size: int = field(
88
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
89
+ )
90
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
91
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
92
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
93
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
94
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
95
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
96
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
97
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
98
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
99
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
100
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
101
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
102
+ push_to_hub: bool = field(
103
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
104
+ )
105
+ hub_model_id: str = field(
106
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
107
+ )
108
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
109
+
110
+ def __post_init__(self):
111
+ if self.output_dir is not None:
112
+ self.output_dir = os.path.expanduser(self.output_dir)
113
+
114
+ def to_dict(self):
115
+ """
116
+ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
117
+ the token values by removing their value.
118
+ """
119
+ d = asdict(self)
120
+ for k, v in d.items():
121
+ if isinstance(v, Enum):
122
+ d[k] = v.value
123
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
124
+ d[k] = [x.value for x in v]
125
+ if k.endswith("_token"):
126
+ d[k] = f"<{k.upper()}>"
127
+ return d
128
+
129
+
130
+ @dataclass
131
+ class ModelArguments:
132
+ """
133
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
134
+ """
135
+
136
+ model_name_or_path: Optional[str] = field(
137
+ default=None,
138
+ metadata={
139
+ "help": "The model checkpoint for weights initialization."
140
+ "Don't set if you want to train a model from scratch."
141
+ },
142
+ )
143
+ model_type: Optional[str] = field(
144
+ default=None,
145
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
146
+ )
147
+ config_name: Optional[str] = field(
148
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
149
+ )
150
+ tokenizer_name: Optional[str] = field(
151
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
152
+ )
153
+ cache_dir: Optional[str] = field(
154
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
155
+ )
156
+ use_fast_tokenizer: bool = field(
157
+ default=True,
158
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
159
+ )
160
+ dtype: Optional[str] = field(
161
+ default="float32",
162
+ metadata={
163
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
164
+ },
165
+ )
166
+
167
+
168
+ @dataclass
169
+ class DataTrainingArguments:
170
+ """
171
+ Arguments pertaining to what data we are going to input our model for training and eval.
172
+ """
173
+
174
+ dataset_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
176
+ )
177
+ dataset_config_name: Optional[str] = field(
178
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
179
+ )
180
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
181
+ validation_file: Optional[str] = field(
182
+ default=None,
183
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
184
+ )
185
+ train_ref_file: Optional[str] = field(
186
+ default=None,
187
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
188
+ )
189
+ validation_ref_file: Optional[str] = field(
190
+ default=None,
191
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
192
+ )
193
+ overwrite_cache: bool = field(
194
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
195
+ )
196
+ validation_split_percentage: Optional[int] = field(
197
+ default=5,
198
+ metadata={
199
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
200
+ },
201
+ )
202
+ max_seq_length: Optional[int] = field(
203
+ default=None,
204
+ metadata={
205
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
206
+ "than this will be truncated. Default to the max input length of the model."
207
+ },
208
+ )
209
+ preprocessing_num_workers: Optional[int] = field(
210
+ default=None,
211
+ metadata={"help": "The number of processes to use for the preprocessing."},
212
+ )
213
+ mlm_probability: float = field(
214
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
215
+ )
216
+ pad_to_max_length: bool = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "Whether to pad all samples to `max_seq_length`. "
220
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
221
+ },
222
+ )
223
+ line_by_line: bool = field(
224
+ default=False,
225
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
226
+ )
227
+
228
+ def __post_init__(self):
229
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
230
+ raise ValueError("Need either a dataset name or a training/validation file.")
231
+ else:
232
+ if self.train_file is not None:
233
+ extension = self.train_file.split(".")[-1]
234
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
235
+ if self.validation_file is not None:
236
+ extension = self.validation_file.split(".")[-1]
237
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
238
+
239
+
240
+ @flax.struct.dataclass
241
+ class FlaxDataCollatorForLanguageModeling:
242
+ """
243
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
244
+ are not all of the same length.
245
+
246
+ Args:
247
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
248
+ The tokenizer used for encoding the data.
249
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
250
+ The probability with which to (randomly) mask tokens in the input.
251
+
252
+ .. note::
253
+
254
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
255
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
256
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
257
+ argument :obj:`return_special_tokens_mask=True`.
258
+ """
259
+
260
+ tokenizer: PreTrainedTokenizerBase
261
+ mlm_probability: float = 0.15
262
+
263
+ def __post_init__(self):
264
+ if self.tokenizer.mask_token is None:
265
+ raise ValueError(
266
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
267
+ "You should pass `mlm=False` to train on causal language modeling instead."
268
+ )
269
+
270
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
271
+ # Handle dict or lists with proper padding and conversion to tensor.
272
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
273
+
274
+ # If special token mask has been preprocessed, pop it from the dict.
275
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
276
+
277
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
278
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
279
+ )
280
+ return batch
281
+
282
+ def mask_tokens(
283
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
284
+ ) -> Tuple[np.ndarray, np.ndarray]:
285
+ """
286
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
287
+ """
288
+ labels = inputs.copy()
289
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
290
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
291
+ special_tokens_mask = special_tokens_mask.astype("bool")
292
+
293
+ probability_matrix[special_tokens_mask] = 0.0
294
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
295
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
296
+
297
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
298
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
299
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
300
+
301
+ # 10% of the time, we replace masked input tokens with random word
302
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
303
+ indices_random &= masked_indices & ~indices_replaced
304
+
305
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
306
+ inputs[indices_random] = random_words[indices_random]
307
+
308
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
309
+ return inputs, labels
310
+
311
+
312
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
313
+ num_samples = len(samples_idx)
314
+ samples_to_remove = num_samples % batch_size
315
+
316
+ if samples_to_remove != 0:
317
+ samples_idx = samples_idx[:-samples_to_remove]
318
+ sections_split = num_samples // batch_size
319
+ batch_idx = np.split(samples_idx, sections_split)
320
+ return batch_idx
321
+
322
+
323
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
324
+ summary_writer.scalar("train_time", train_time, step)
325
+
326
+ train_metrics = get_metrics(train_metrics)
327
+ for key, vals in train_metrics.items():
328
+ tag = f"train_{key}"
329
+ for i, val in enumerate(vals):
330
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
331
+
332
+
333
+ def write_eval_metric(summary_writer, eval_metrics, step):
334
+ for metric_name, value in eval_metrics.items():
335
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
336
+
337
+
338
+ def main():
339
+ # See all possible arguments in src/transformers/training_args.py
340
+ # or by passing the --help flag to this script.
341
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
342
+
343
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
344
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
345
+ # If we pass only one argument to the script and it's the path to a json file,
346
+ # let's parse it to get our arguments.
347
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
348
+ else:
349
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
350
+
351
+ if (
352
+ os.path.exists(training_args.output_dir)
353
+ and os.listdir(training_args.output_dir)
354
+ and training_args.do_train
355
+ and not training_args.overwrite_output_dir
356
+ ):
357
+ raise ValueError(
358
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
359
+ "Use --overwrite_output_dir to overcome."
360
+ )
361
+
362
+ # Setup logging
363
+ logging.basicConfig(
364
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
365
+ level=logging.INFO,
366
+ datefmt="[%X]",
367
+ )
368
+
369
+ # Log on each process the small summary:
370
+ logger = logging.getLogger(__name__)
371
+
372
+ # Set the verbosity to info of the Transformers logger (on main process only):
373
+ logger.info(f"Training/evaluation parameters {training_args}")
374
+
375
+ # Set seed before initializing model.
376
+ set_seed(training_args.seed)
377
+
378
+ # Handle the repository creation
379
+ if training_args.push_to_hub:
380
+ if training_args.hub_model_id is None:
381
+ repo_name = get_full_repo_name(
382
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
383
+ )
384
+ else:
385
+ repo_name = training_args.hub_model_id
386
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
387
+
388
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
389
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
390
+ # (the dataset will be downloaded automatically from the datasets Hub).
391
+ #
392
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
393
+ # 'text' is found. You can easily tweak this behavior (see below).
394
+ #
395
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
396
+ # download the dataset.
397
+ if data_args.dataset_name is not None:
398
+ # Downloading and loading a dataset from the hub.
399
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
400
+
401
+ if "validation" not in datasets.keys():
402
+ datasets["validation"] = load_dataset(
403
+ data_args.dataset_name,
404
+ data_args.dataset_config_name,
405
+ split=f"train[:{data_args.validation_split_percentage}%]",
406
+ cache_dir=model_args.cache_dir,
407
+ )
408
+ datasets["train"] = load_dataset(
409
+ data_args.dataset_name,
410
+ data_args.dataset_config_name,
411
+ split=f"train[{data_args.validation_split_percentage}%:]",
412
+ cache_dir=model_args.cache_dir,
413
+ )
414
+ else:
415
+ data_files = {}
416
+ if data_args.train_file is not None:
417
+ data_files["train"] = data_args.train_file
418
+ if data_args.validation_file is not None:
419
+ data_files["validation"] = data_args.validation_file
420
+ extension = data_args.train_file.split(".")[-1]
421
+ if extension == "txt":
422
+ extension = "text"
423
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
424
+
425
+ if "validation" not in datasets.keys():
426
+ datasets["validation"] = load_dataset(
427
+ extension,
428
+ data_files=data_files,
429
+ split=f"train[:{data_args.validation_split_percentage}%]",
430
+ cache_dir=model_args.cache_dir,
431
+ )
432
+ datasets["train"] = load_dataset(
433
+ extension,
434
+ data_files=data_files,
435
+ split=f"train[{data_args.validation_split_percentage}%:]",
436
+ cache_dir=model_args.cache_dir,
437
+ )
438
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
439
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
440
+
441
+ # Load pretrained model and tokenizer
442
+
443
+ # Distributed training:
444
+ # The .from_pretrained methods guarantee that only one local process can concurrently
445
+ # download model & vocab.
446
+ if model_args.config_name:
447
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
448
+ elif model_args.model_name_or_path:
449
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
450
+ else:
451
+ config = CONFIG_MAPPING[model_args.model_type]()
452
+ logger.warning("You are instantiating a new config instance from scratch.")
453
+
454
+ if model_args.tokenizer_name:
455
+ tokenizer = AutoTokenizer.from_pretrained(
456
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
457
+ )
458
+ elif model_args.model_name_or_path:
459
+ tokenizer = AutoTokenizer.from_pretrained(
460
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
461
+ )
462
+ else:
463
+ raise ValueError(
464
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
465
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
466
+ )
467
+
468
+ # Preprocessing the datasets.
469
+ # First we tokenize all the texts.
470
+ if training_args.do_train:
471
+ column_names = datasets["train"].column_names
472
+ else:
473
+ column_names = datasets["validation"].column_names
474
+ text_column_name = "text" if "text" in column_names else column_names[0]
475
+
476
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
477
+
478
+ if data_args.line_by_line:
479
+ # When using line_by_line, we just tokenize each nonempty line.
480
+ padding = "max_length" if data_args.pad_to_max_length else False
481
+
482
+ def tokenize_function(examples):
483
+ # Remove empty lines
484
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
485
+ return tokenizer(
486
+ examples,
487
+ return_special_tokens_mask=True,
488
+ padding=padding,
489
+ truncation=True,
490
+ max_length=max_seq_length,
491
+ )
492
+
493
+ tokenized_datasets = datasets.map(
494
+ tokenize_function,
495
+ input_columns=[text_column_name],
496
+ batched=True,
497
+ num_proc=data_args.preprocessing_num_workers,
498
+ remove_columns=column_names,
499
+ load_from_cache_file=not data_args.overwrite_cache,
500
+ )
501
+
502
+ else:
503
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
504
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
505
+ # efficient when it receives the `special_tokens_mask`.
506
+ def tokenize_function(examples):
507
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
508
+
509
+ tokenized_datasets = datasets.map(
510
+ tokenize_function,
511
+ batched=True,
512
+ num_proc=data_args.preprocessing_num_workers,
513
+ remove_columns=column_names,
514
+ load_from_cache_file=not data_args.overwrite_cache,
515
+ )
516
+
517
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
518
+ # max_seq_length.
519
+ def group_texts(examples):
520
+ # Concatenate all texts.
521
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
522
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
523
+ # We drop the small remainder; we could pad instead of dropping if the model supported it. You can
524
+ # customize this part to your needs.
525
+ if total_length >= max_seq_length:
526
+ total_length = (total_length // max_seq_length) * max_seq_length
527
+ # Split by chunks of max_len.
528
+ result = {
529
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
530
+ for k, t in concatenated_examples.items()
531
+ }
532
+ return result
533
+
534
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
535
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
536
+ # might be slower to preprocess.
537
+ #
538
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
539
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
540
+ tokenized_datasets = tokenized_datasets.map(
541
+ group_texts,
542
+ batched=True,
543
+ num_proc=data_args.preprocessing_num_workers,
544
+ load_from_cache_file=not data_args.overwrite_cache,
545
+ )
546
+
547
+ # Enable tensorboard only on the master node
548
+ has_tensorboard = is_tensorboard_available()
549
+ if has_tensorboard and jax.process_index() == 0:
550
+ try:
551
+ # Enable Weights & Biases
552
+ import wandb
553
+ wandb.init(
554
+ entity='versae',
555
+ project='roberta-base-ncc',
556
+ sync_tensorboard=False,
557
+ )
558
+ wandb.config.update(training_args)
559
+ wandb.config.update(model_args)
560
+ wandb.config.update(data_args)
561
+
562
+ from flax.metrics.tensorboard import SummaryWriter
563
+
564
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
565
+ except ImportError as ie:
566
+ has_tensorboard = False
567
+ logger.warning(
568
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
569
+ )
570
+ else:
571
+ logger.warning(
572
+ "Unable to display metrics through TensorBoard because the package is not installed: "
573
+ "Please run pip install tensorboard to enable."
574
+ )
575
+
576
+ # Data collator
577
+ # This one will take care of randomly masking the tokens.
578
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
579
+
580
+ # Initialize our training
581
+ rng = jax.random.PRNGKey(training_args.seed)
582
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
583
+
584
+ if model_args.model_name_or_path:
585
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
586
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
587
+ )
588
+ else:
589
+ model = FlaxAutoModelForMaskedLM.from_config(
590
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
591
+ )
592
+
593
+ # Store some constants
594
+ num_epochs = int(training_args.num_train_epochs)
595
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
596
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
597
+
598
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
599
+
600
+ # Create learning rate schedule
601
+ warmup_fn = optax.linear_schedule(
602
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
603
+ )
604
+ decay_fn = optax.linear_schedule(
605
+ init_value=training_args.learning_rate,
606
+ end_value=0,
607
+ transition_steps=num_train_steps - training_args.warmup_steps,
608
+ )
609
+ linear_decay_lr_schedule_fn = optax.join_schedules(
610
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
611
+ )
612
+
613
+ # We use Optax's "masking" functionality to not apply weight decay
614
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
615
+ # mask boolean with the same structure as the parameters.
616
+ # The mask is True for parameters that should be decayed.
617
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
618
+ # For other models, one should correct the layer norm parameter naming
619
+ # accordingly.
620
+ def decay_mask_fn(params):
621
+ flat_params = traverse_util.flatten_dict(params)
622
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
623
+ return traverse_util.unflatten_dict(flat_mask)
624
+
625
+ # Create the optimizer (Adafactor if requested, otherwise AdamW)
626
+ if training_args.adafactor:
627
+ # We use the default parameters here to initialize Adafactor.
628
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
629
+ optimizer = optax.adafactor(
630
+ learning_rate=linear_decay_lr_schedule_fn,
631
+ )
632
+ else:
633
+ optimizer = optax.adamw(
634
+ learning_rate=linear_decay_lr_schedule_fn,
635
+ b1=training_args.adam_beta1,
636
+ b2=training_args.adam_beta2,
637
+ eps=training_args.adam_epsilon,
638
+ weight_decay=training_args.weight_decay,
639
+ mask=decay_mask_fn,
640
+ )
641
+
642
+ # Setup train state
643
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
644
+
645
+ # Define gradient update step fn
646
+ def train_step(state, batch, dropout_rng):
647
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
648
+
649
+ def loss_fn(params):
650
+ labels = batch.pop("labels")
651
+
652
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
653
+
654
+ # compute loss, ignore padded input tokens
655
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
656
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
657
+
658
+ # take average
659
+ loss = loss.sum() / label_mask.sum()
660
+
661
+ return loss
662
+
663
+ grad_fn = jax.value_and_grad(loss_fn)
664
+ loss, grad = grad_fn(state.params)
665
+ grad = jax.lax.pmean(grad, "batch")
666
+ new_state = state.apply_gradients(grads=grad)
667
+
668
+ metrics = jax.lax.pmean(
669
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
670
+ )
671
+
672
+ return new_state, metrics, new_dropout_rng
673
+
674
+ # Create parallel version of the train step
675
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
676
+
677
+ # Define eval fn
678
+ def eval_step(params, batch):
679
+ labels = batch.pop("labels")
680
+
681
+ logits = model(**batch, params=params, train=False)[0]
682
+
683
+ # compute loss, ignore padded input tokens
684
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
685
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
686
+
687
+ # compute accuracy
688
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
689
+
690
+ # summarize metrics
691
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
692
+ metrics = jax.lax.psum(metrics, axis_name="batch")
693
+
694
+ return metrics
695
+
696
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
697
+
698
+ # Replicate the train state on each device
699
+ state = jax_utils.replicate(state)
700
+
701
+ train_time = 0
702
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
703
+ for epoch in epochs:
704
+ # ======================== Training ================================
705
+ train_start = time.time()
706
+ train_metrics = []
707
+
708
+ # Create sampling rng
709
+ rng, input_rng = jax.random.split(rng)
710
+
711
+ # Generate an epoch by shuffling sampling indices from the train dataset
712
+ num_train_samples = len(tokenized_datasets["train"])
713
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
714
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
715
+
716
+ # Gather the indexes for creating the batch and do a training step
717
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
718
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
719
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
720
+
721
+ # Model forward
722
+ model_inputs = shard(model_inputs.data)
723
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
724
+ train_metrics.append(train_metric)
725
+
726
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
727
+
728
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
729
+ # Save metrics
730
+ train_metric = jax_utils.unreplicate(train_metric)
731
+ train_time += time.time() - train_start
732
+ if has_tensorboard and jax.process_index() == 0:
733
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
734
+
735
+ epochs.write(
736
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
737
+ )
738
+
739
+ train_metrics = []
740
+
741
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
742
+ # ======================== Evaluating ==============================
743
+ num_eval_samples = len(tokenized_datasets["validation"])
744
+ eval_samples_idx = jnp.arange(num_eval_samples)
745
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
746
+
747
+ eval_metrics = []
748
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
749
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
750
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
751
+
752
+ # Model forward
753
+ model_inputs = shard(model_inputs.data)
754
+ metrics = p_eval_step(state.params, model_inputs)
755
+ eval_metrics.append(metrics)
756
+
757
+ # normalize eval metrics
758
+ eval_metrics = get_metrics(eval_metrics)
759
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
760
+ eval_normalizer = eval_metrics.pop("normalizer")
761
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
762
+
763
+ # Update progress bar
764
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
765
+
766
+ # Save metrics
767
+ if has_tensorboard and jax.process_index() == 0:
768
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
769
+
770
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
771
+ # Save a checkpoint every save_steps and push it to the hub
772
+ if jax.process_index() == 0:
773
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
774
+ model.save_pretrained(training_args.output_dir, params=params)
775
+ tokenizer.save_pretrained(training_args.output_dir)
776
+ if training_args.push_to_hub:
777
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
778
+
779
+ # Eval after training
780
+ if training_args.do_eval:
781
+ num_eval_samples = len(tokenized_datasets["validation"])
782
+ eval_samples_idx = jnp.arange(num_eval_samples)
783
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
784
+
785
+ eval_metrics = []
786
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
787
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
788
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
789
+
790
+ # Model forward
791
+ model_inputs = shard(model_inputs.data)
792
+ metrics = p_eval_step(state.params, model_inputs)
793
+ eval_metrics.append(metrics)
794
+
795
+ # normalize eval metrics
796
+ eval_metrics = get_metrics(eval_metrics)
797
+ eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
798
+ eval_normalizer = eval_metrics.pop("normalizer")
799
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
800
+
801
+ try:
802
+ perplexity = math.exp(eval_metrics["loss"])
803
+ except OverflowError:
804
+ perplexity = float("inf")
805
+ eval_metrics["perplexity"] = perplexity
806
+
807
+ if jax.process_index() == 0:
808
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
809
+ path = os.path.join(training_args.output_dir, "eval_results.json")
810
+ with open(path, "w") as f:
811
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
812
+
813
+
814
+ if __name__ == "__main__":
815
+ main()
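
For reference, the learning-rate schedule the script builds (a linear warmup joined to a linear decay via optax.join_schedules) can be probed on its own. The sketch below is not part of the committed script; the peak rate (6e-4) and warmup length (10,000 steps) are taken from the run configuration that follows, while the total step count is a made-up placeholder, since the real value depends on the tokenized dataset size.

import optax

num_train_steps = 100_000   # placeholder; the script derives this from the dataset size
learning_rate = 6e-4        # from the run config below
warmup_steps = 10_000       # from the run config below

warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - warmup_steps)
schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

for step in (0, 5_000, 10_000, 55_000, 100_000):
    print(step, float(schedule(step)))   # ramps 0 -> 6e-4 over the warmup, then decays linearly back to 0
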
wandb/run-20220114_234119-1zya86oe/files/config.yaml ADDED
@@ -0,0 +1,152 @@
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_mlm_flax.py
8
+ framework: huggingface
9
+ huggingface_version: 4.16.0.dev0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1642203679
14
+ t:
15
+ 1:
16
+ - 2
17
+ - 3
18
+ - 11
19
+ - 12
20
+ 2:
21
+ - 2
22
+ - 3
23
+ - 11
24
+ - 12
25
+ 4: 3.8.10
26
+ 5: 0.12.9
27
+ 6: 4.16.0.dev0
28
+ 8:
29
+ - 5
30
+ adafactor:
31
+ desc: null
32
+ value: false
33
+ adam_beta1:
34
+ desc: null
35
+ value: 0.9
36
+ adam_beta2:
37
+ desc: null
38
+ value: 0.98
39
+ adam_epsilon:
40
+ desc: null
41
+ value: 1.0e-06
42
+ cache_dir:
43
+ desc: null
44
+ value: null
45
+ config_name:
46
+ desc: null
47
+ value: roberta-base
48
+ dataset_config_name:
49
+ desc: null
50
+ value: null
51
+ dataset_name:
52
+ desc: null
53
+ value: NbAiLab/NCC
54
+ do_eval:
55
+ desc: null
56
+ value: true
57
+ do_train:
58
+ desc: null
59
+ value: true
60
+ dtype:
61
+ desc: null
62
+ value: bfloat16
63
+ eval_steps:
64
+ desc: null
65
+ value: 1000
66
+ hub_model_id:
67
+ desc: null
68
+ value: null
69
+ hub_token:
70
+ desc: null
71
+ value: null
72
+ learning_rate:
73
+ desc: null
74
+ value: 0.0006
75
+ line_by_line:
76
+ desc: null
77
+ value: false
78
+ logging_steps:
79
+ desc: null
80
+ value: 1000
81
+ max_seq_length:
82
+ desc: null
83
+ value: 128
84
+ mlm_probability:
85
+ desc: null
86
+ value: 0.15
87
+ model_name_or_path:
88
+ desc: null
89
+ value: null
90
+ model_type:
91
+ desc: null
92
+ value: roberta
93
+ num_train_epochs:
94
+ desc: null
95
+ value: 3.0
96
+ output_dir:
97
+ desc: null
98
+ value: ./
99
+ overwrite_cache:
100
+ desc: null
101
+ value: false
102
+ overwrite_output_dir:
103
+ desc: null
104
+ value: true
105
+ pad_to_max_length:
106
+ desc: null
107
+ value: true
108
+ per_device_eval_batch_size:
109
+ desc: null
110
+ value: 232
111
+ per_device_train_batch_size:
112
+ desc: null
113
+ value: 232
114
+ preprocessing_num_workers:
115
+ desc: null
116
+ value: null
117
+ push_to_hub:
118
+ desc: null
119
+ value: true
120
+ save_steps:
121
+ desc: null
122
+ value: 1000
123
+ seed:
124
+ desc: null
125
+ value: 42
126
+ tokenizer_name:
127
+ desc: null
128
+ value: NbAiLab/nb-roberta-base
129
+ train_file:
130
+ desc: null
131
+ value: null
132
+ train_ref_file:
133
+ desc: null
134
+ value: null
135
+ use_fast_tokenizer:
136
+ desc: null
137
+ value: true
138
+ validation_file:
139
+ desc: null
140
+ value: null
141
+ validation_ref_file:
142
+ desc: null
143
+ value: null
144
+ validation_split_percentage:
145
+ desc: null
146
+ value: 5
147
+ warmup_steps:
148
+ desc: null
149
+ value: 10000
150
+ weight_decay:
151
+ desc: null
152
+ value: 0.01
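
The per-device batch size of 232 above is multiplied by jax.device_count() inside the script, so the effective batch is much larger. A quick sketch of the arithmetic, assuming 8 local devices (a typical TPU v3-8 host; the actual count is whatever jax.device_count() reports at run time):

per_device_train_batch_size = 232          # from the config above
device_count = 8                           # assumption; query jax.device_count() for the real value
train_batch_size = per_device_train_batch_size * device_count
print(train_batch_size)                    # 1856 sequences of max_seq_length=128 per optimizer step
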
wandb/run-20220114_234119-1zya86oe/files/diff.patch ADDED
File without changes
wandb/run-20220114_234119-1zya86oe/files/output.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85f905cb3060152eb20f5227d157a7605e33c762b80a6b8f4792e5791a1acd2b
3
+ size 26403055
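
This output.log (like debug-internal.log and the .wandb file further down) is stored as a Git LFS pointer rather than as raw content: three key/value lines giving the spec version, the SHA-256 object id, and the true size in bytes. A minimal sketch for reading such a pointer, using the path of the file committed above:

def read_lfs_pointer(path):
    # Parse a Git LFS pointer file into {"version": ..., "oid": "sha256:...", "size": <int>}.
    fields = {}
    with open(path) as fh:
        for line in fh:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    fields["size"] = int(fields["size"])
    return fields

info = read_lfs_pointer("wandb/run-20220114_234119-1zya86oe/files/output.log")
print(info["oid"], info["size"])   # the real log behind this pointer is 26403055 bytes
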
wandb/run-20220114_234119-1zya86oe/files/requirements.txt ADDED
@@ -0,0 +1,122 @@
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ astunparse==1.6.3
5
+ async-timeout==4.0.2
6
+ attrs==21.4.0
7
+ backcall==0.2.0
8
+ cachetools==4.2.4
9
+ certifi==2021.10.8
10
+ charset-normalizer==2.0.10
11
+ chex==0.1.0
12
+ click==8.0.3
13
+ clu==0.0.6
14
+ configparser==5.2.0
15
+ contextlib2==21.6.0
16
+ cycler==0.11.0
17
+ datasets==1.17.1.dev0
18
+ decorator==5.1.0
19
+ dill==0.3.4
20
+ dm-tree==0.1.6
21
+ docker-pycreds==0.4.0
22
+ filelock==3.4.2
23
+ flatbuffers==2.0
24
+ flax==0.3.6
25
+ fonttools==4.28.5
26
+ frozenlist==1.2.0
27
+ fsspec==2021.11.1
28
+ future==0.18.2
29
+ gast==0.4.0
30
+ gitdb==4.0.9
31
+ gitpython==3.1.26
32
+ google-auth-oauthlib==0.4.6
33
+ google-auth==2.3.3
34
+ google-pasta==0.2.0
35
+ googleapis-common-protos==1.54.0
36
+ grpcio==1.43.0
37
+ h5py==3.6.0
38
+ huggingface-hub==0.2.1
39
+ idna==3.3
40
+ importlib-metadata==4.10.0
41
+ importlib-resources==5.4.0
42
+ ipython==7.31.0
43
+ jax==0.2.26
44
+ jaxlib==0.1.75
45
+ jedi==0.18.1
46
+ joblib==1.1.0
47
+ keras-preprocessing==1.1.2
48
+ keras==2.7.0
49
+ kiwisolver==1.3.2
50
+ libclang==12.0.0
51
+ libtpu-nightly==0.1.dev20211208
52
+ markdown==3.3.6
53
+ matplotlib-inline==0.1.3
54
+ matplotlib==3.5.1
55
+ ml-collections==0.1.0
56
+ msgpack==1.0.3
57
+ multidict==5.2.0
58
+ multiprocess==0.70.12.2
59
+ numpy==1.22.0
60
+ oauthlib==3.1.1
61
+ opt-einsum==3.3.0
62
+ optax==0.1.0
63
+ packaging==21.3
64
+ pandas==1.3.5
65
+ parso==0.8.3
66
+ pathtools==0.1.2
67
+ pexpect==4.8.0
68
+ pickleshare==0.7.5
69
+ pillow==9.0.0
70
+ pip==20.0.2
71
+ pkg-resources==0.0.0
72
+ promise==2.3
73
+ prompt-toolkit==3.0.24
74
+ protobuf==3.19.1
75
+ psutil==5.9.0
76
+ ptyprocess==0.7.0
77
+ pyarrow==6.0.1
78
+ pyasn1-modules==0.2.8
79
+ pyasn1==0.4.8
80
+ pygments==2.11.1
81
+ pyparsing==3.0.6
82
+ python-dateutil==2.8.2
83
+ pytz==2021.3
84
+ pyyaml==6.0
85
+ regex==2021.11.10
86
+ requests-oauthlib==1.3.0
87
+ requests==2.27.0
88
+ rsa==4.8
89
+ sacremoses==0.0.46
90
+ scipy==1.7.3
91
+ sentry-sdk==1.5.2
92
+ setuptools==44.0.0
93
+ shortuuid==1.0.8
94
+ six==1.16.0
95
+ smmap==5.0.0
96
+ subprocess32==3.5.4
97
+ tensorboard-data-server==0.6.1
98
+ tensorboard-plugin-wit==1.8.0
99
+ tensorboard==2.7.0
100
+ tensorflow-cpu==2.7.0
101
+ tensorflow-datasets==4.4.0
102
+ tensorflow-estimator==2.7.0
103
+ tensorflow-io-gcs-filesystem==0.23.1
104
+ tensorflow-metadata==1.5.0
105
+ tensorflow==2.7.0
106
+ termcolor==1.1.0
107
+ tokenizers==0.11.2
108
+ toolz==0.11.2
109
+ tqdm==4.62.3
110
+ traitlets==5.1.1
111
+ transformers==4.16.0.dev0
112
+ typing-extensions==3.10.0.2
113
+ urllib3==1.26.7
114
+ wandb==0.12.9
115
+ wcwidth==0.2.5
116
+ werkzeug==2.0.2
117
+ wheel==0.37.1
118
+ wrapt==1.13.3
119
+ xxhash==2.0.2
120
+ yarl==1.7.2
121
+ yaspin==2.1.0
122
+ zipp==3.7.0
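
The freeze list above pins the exact JAX/Flax/Optax/Transformers stack used for this run. A small sketch for comparing another environment against it; the path is the file committed above, and which packages to check is an arbitrary choice:

from importlib.metadata import PackageNotFoundError, version

pins = {}
with open("wandb/run-20220114_234119-1zya86oe/files/requirements.txt") as fh:
    for line in fh:
        name, _, pinned = line.strip().partition("==")
        pins[name] = pinned

for name in ("jax", "jaxlib", "flax", "optax", "transformers", "datasets"):
    try:
        installed = version(name)
    except PackageNotFoundError:
        installed = "missing"
    status = "ok" if installed == pins.get(name) else "differs"
    print(f"{name}: pinned {pins.get(name)} / installed {installed} ({status})")
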
wandb/run-20220114_234119-1zya86oe/files/wandb-metadata.json ADDED
@@ -0,0 +1,47 @@
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-01-14T23:41:22.464389",
5
+ "startedAt": "2022-01-14T23:41:19.296699",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--output_dir=./",
11
+ "--model_type=roberta",
12
+ "--config_name=roberta-base",
13
+ "--tokenizer_name=NbAiLab/nb-roberta-base",
14
+ "--dataset_name=NbAiLab/NCC",
15
+ "--max_seq_length=128",
16
+ "--weight_decay=0.01",
17
+ "--per_device_train_batch_size=232",
18
+ "--per_device_eval_batch_size=232",
19
+ "--pad_to_max_length",
20
+ "--learning_rate=6e-4",
21
+ "--warmup_steps=10000",
22
+ "--overwrite_output_dir",
23
+ "--num_train_epochs=3",
24
+ "--adam_beta1=0.9",
25
+ "--adam_beta2=0.98",
26
+ "--adam_epsilon=1e-6",
27
+ "--logging_steps=1000",
28
+ "--save_steps=1000",
29
+ "--eval_steps=1000",
30
+ "--do_train",
31
+ "--do_eval",
32
+ "--dtype=bfloat16",
33
+ "--push_to_hub"
34
+ ],
35
+ "state": "running",
36
+ "program": "run_mlm_flax.py",
37
+ "codePath": "run_mlm_flax.py",
38
+ "git": {
39
+ "remote": "https://huggingface.co/versae/roberta-base-ncc",
40
+ "commit": "502df078f73cf93ca9380fcac1c9b9c7598a445f"
41
+ },
42
+ "email": "versae@gmail.com",
43
+ "root": "/data/roberta-base-ncc",
44
+ "host": "t1v-n-eedfb410-w-0",
45
+ "username": "javierr",
46
+ "executable": "/data/flax/bin/python"
47
+ }
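
The metadata above records both the full command line of the run and the Hub repository it pushes to (versae/roberta-base-ncc, from the git remote). Once a checkpoint has been uploaded, it can presumably be loaded back with the same Flax classes the script imports; a minimal sketch, with the model id taken from that remote and a made-up Norwegian example sentence:

from transformers import AutoTokenizer, FlaxAutoModelForMaskedLM

model_id = "versae/roberta-base-ncc"   # from the git remote recorded above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = FlaxAutoModelForMaskedLM.from_pretrained(model_id)

inputs = tokenizer("Dette er en <mask>.", return_tensors="np")
logits = model(**inputs).logits
print(logits.shape)   # (1, sequence_length, vocab_size)
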
wandb/run-20220114_234119-1zya86oe/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"_wandb": {"runtime": 356856}}
wandb/run-20220114_234119-1zya86oe/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0a61912984f47a6ce815f00cc68ecb84206e92936ff60cd6573447587dd00aa
3
+ size 38335870
wandb/run-20220114_234119-1zya86oe/logs/debug.log ADDED
@@ -0,0 +1,168 @@
1
+ 2022-01-14 23:41:19,298 INFO MainThread:10537 [wandb_setup.py:_flush():71] setting env: {}
2
+ 2022-01-14 23:41:19,298 INFO MainThread:10537 [wandb_setup.py:_flush():71] setting login settings: {}
3
+ 2022-01-14 23:41:19,298 INFO MainThread:10537 [wandb_init.py:_log_setup():371] Logging user logs to /data/roberta-base-ncc/wandb/run-20220114_234119-1zya86oe/logs/debug.log
4
+ 2022-01-14 23:41:19,298 INFO MainThread:10537 [wandb_init.py:_log_setup():372] Logging internal logs to /data/roberta-base-ncc/wandb/run-20220114_234119-1zya86oe/logs/debug-internal.log
5
+ 2022-01-14 23:41:19,298 INFO MainThread:10537 [wandb_init.py:init():404] calling init triggers
6
+ 2022-01-14 23:41:19,298 INFO MainThread:10537 [wandb_init.py:init():409] wandb.init called with sweep_config: {}
7
+ config: {}
8
+ 2022-01-14 23:41:19,298 INFO MainThread:10537 [wandb_init.py:init():460] starting backend
9
+ 2022-01-14 23:41:19,298 INFO MainThread:10537 [backend.py:_multiprocessing_setup():99] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
10
+ 2022-01-14 23:41:19,328 INFO MainThread:10537 [backend.py:ensure_launched():216] starting backend process...
11
+ 2022-01-14 23:41:19,355 INFO MainThread:10537 [backend.py:ensure_launched():221] started backend process with pid: 11989
12
+ 2022-01-14 23:41:19,357 INFO MainThread:10537 [wandb_init.py:init():469] backend started and connected
13
+ 2022-01-14 23:41:19,369 INFO MainThread:10537 [wandb_init.py:init():533] updated telemetry
14
+ 2022-01-14 23:41:19,437 INFO MainThread:10537 [wandb_init.py:init():563] communicating current version
15
+ 2022-01-14 23:41:20,144 INFO MainThread:10537 [wandb_init.py:init():568] got version response
16
+ 2022-01-14 23:41:20,145 INFO MainThread:10537 [wandb_init.py:init():578] communicating run to backend with 30 second timeout
17
+ 2022-01-14 23:41:20,323 INFO MainThread:10537 [wandb_init.py:init():606] starting run threads in backend
18
+ 2022-01-14 23:41:25,327 INFO MainThread:10537 [wandb_run.py:_console_start():1810] atexit reg
19
+ 2022-01-14 23:41:25,328 INFO MainThread:10537 [wandb_run.py:_redirect():1684] redirect: SettingsConsole.REDIRECT
20
+ 2022-01-14 23:41:25,328 INFO MainThread:10537 [wandb_run.py:_redirect():1689] Redirecting console.
21
+ 2022-01-14 23:41:25,330 INFO MainThread:10537 [wandb_run.py:_redirect():1745] Redirects installed.
22
+ 2022-01-14 23:41:25,331 INFO MainThread:10537 [wandb_init.py:init():633] run started, returning control to user process
23
+ 2022-01-14 23:41:25,331 INFO MainThread:10537 [wandb_run.py:_config_callback():956] config_cb None None {'output_dir': './', 'overwrite_output_dir': True, 'do_train': True, 'do_eval': True, 'per_device_train_batch_size': 232, 'per_device_eval_batch_size': 232, 'learning_rate': 0.0006, 'weight_decay': 0.01, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-06, 'adafactor': False, 'num_train_epochs': 3.0, 'warmup_steps': 10000, 'logging_steps': 1000, 'save_steps': 1000, 'eval_steps': 1000, 'seed': 42, 'push_to_hub': True, 'hub_model_id': None, 'hub_token': None}
24
+ 2022-01-14 23:41:25,332 INFO MainThread:10537 [wandb_run.py:_config_callback():956] config_cb None None {'model_name_or_path': None, 'model_type': 'roberta', 'config_name': 'roberta-base', 'tokenizer_name': 'NbAiLab/nb-roberta-base', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'bfloat16'}
25
+ 2022-01-14 23:41:25,332 INFO MainThread:10537 [wandb_run.py:_config_callback():956] config_cb None None {'dataset_name': 'NbAiLab/NCC', 'dataset_config_name': None, 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 128, 'preprocessing_num_workers': None, 'mlm_probability': 0.15, 'pad_to_max_length': True, 'line_by_line': False}
26
+ 2022-01-19 02:48:53,379 INFO MainThread:10537 [wandb_run.py:_atexit_cleanup():1780] got exitcode: 0
27
+ 2022-01-19 02:48:53,381 INFO MainThread:10537 [wandb_run.py:_restore():1752] restore
28
+ 2022-01-19 02:48:56,346 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
29
+ wandb_count: 1
30
+ other_count: 1
31
+ }
32
+ pusher_stats {
33
+ uploaded_bytes: 37446
34
+ total_bytes: 37446
35
+ }
36
+
37
+ 2022-01-19 02:48:56,559 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
38
+ wandb_count: 1
39
+ other_count: 1
40
+ }
41
+ pusher_stats {
42
+ uploaded_bytes: 37446
43
+ total_bytes: 37446
44
+ }
45
+
46
+ 2022-01-19 02:48:56,919 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
47
+ wandb_count: 5
48
+ other_count: 1
49
+ }
50
+ pusher_stats {
51
+ uploaded_bytes: 37446
52
+ total_bytes: 26444957
53
+ }
54
+
55
+ 2022-01-19 02:48:57,021 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
56
+ wandb_count: 5
57
+ other_count: 1
58
+ }
59
+ pusher_stats {
60
+ uploaded_bytes: 37446
61
+ total_bytes: 26444957
62
+ }
63
+
64
+ 2022-01-19 02:48:57,123 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
65
+ wandb_count: 5
66
+ other_count: 1
67
+ }
68
+ pusher_stats {
69
+ uploaded_bytes: 10068910
70
+ total_bytes: 26444957
71
+ }
72
+
73
+ 2022-01-19 02:48:57,225 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
74
+ wandb_count: 5
75
+ other_count: 1
76
+ }
77
+ pusher_stats {
78
+ uploaded_bytes: 13042606
79
+ total_bytes: 26444957
80
+ }
81
+
82
+ 2022-01-19 02:48:57,327 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
83
+ wandb_count: 5
84
+ other_count: 1
85
+ }
86
+ pusher_stats {
87
+ uploaded_bytes: 21463982
88
+ total_bytes: 26444957
89
+ }
90
+
91
+ 2022-01-19 02:48:57,429 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
92
+ wandb_count: 5
93
+ other_count: 1
94
+ }
95
+ pusher_stats {
96
+ uploaded_bytes: 26444957
97
+ total_bytes: 26444957
98
+ }
99
+
100
+ 2022-01-19 02:48:57,531 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
101
+ wandb_count: 5
102
+ other_count: 1
103
+ }
104
+ pusher_stats {
105
+ uploaded_bytes: 26444957
106
+ total_bytes: 26444957
107
+ }
108
+
109
+ 2022-01-19 02:48:57,633 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
110
+ wandb_count: 5
111
+ other_count: 1
112
+ }
113
+ pusher_stats {
114
+ uploaded_bytes: 26444957
115
+ total_bytes: 26444957
116
+ }
117
+
118
+ 2022-01-19 02:48:57,735 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
119
+ wandb_count: 5
120
+ other_count: 1
121
+ }
122
+ pusher_stats {
123
+ uploaded_bytes: 26444957
124
+ total_bytes: 26444957
125
+ }
126
+
127
+ 2022-01-19 02:48:57,837 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
128
+ wandb_count: 5
129
+ other_count: 1
130
+ }
131
+ pusher_stats {
132
+ uploaded_bytes: 26444957
133
+ total_bytes: 26444957
134
+ }
135
+
136
+ 2022-01-19 02:48:57,939 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
137
+ wandb_count: 5
138
+ other_count: 1
139
+ }
140
+ pusher_stats {
141
+ uploaded_bytes: 26444957
142
+ total_bytes: 26444957
143
+ }
144
+
145
+ 2022-01-19 02:48:58,457 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
146
+ wandb_count: 5
147
+ other_count: 1
148
+ }
149
+ pusher_stats {
150
+ uploaded_bytes: 26444957
151
+ total_bytes: 26444957
152
+ }
153
+
154
+ 2022-01-19 02:48:58,818 INFO MainThread:10537 [wandb_run.py:_wait_for_finish():1912] got exit ret: done: true
155
+ exit_result {
156
+ }
157
+ file_counts {
158
+ wandb_count: 5
159
+ other_count: 1
160
+ }
161
+ pusher_stats {
162
+ uploaded_bytes: 26444957
163
+ total_bytes: 26444957
164
+ }
165
+ local_info {
166
+ }
167
+
168
+ 2022-01-19 02:49:00,429 INFO MainThread:10537 [wandb_run.py:_append_files():2180] logging synced files
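
The tail of this debug log polls the wandb file pusher until uploaded_bytes catches up with total_bytes (26,444,957 bytes here). A small sketch that pulls those progress figures back out of the log, using the path of the file committed above:

import re

with open("wandb/run-20220114_234119-1zya86oe/logs/debug.log") as fh:
    text = fh.read()

progress = [
    100 * int(uploaded) / int(total)
    for uploaded, total in re.findall(r"uploaded_bytes: (\d+)\s+total_bytes: (\d+)", text)
]
print([f"{p:.0f}%" for p in progress])   # the total grows as files are queued, then the uploads drain back to 100%
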
wandb/run-20220114_234119-1zya86oe/run-1zya86oe.wandb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f93d599b590506c08952e582e6ab64f317f0ac3488bf1937504d605ab8ecf5b
3
+ size 118429569
wandb/run-20220119_161158-274aad95/files/code/run_mlm_flax.py ADDED
@@ -0,0 +1,815 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=fill-mask
22
+ """
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import asdict, dataclass, field
30
+ from enum import Enum
31
+ from itertools import chain
32
+
33
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
34
+ from pathlib import Path
35
+ from typing import Dict, List, Optional, Tuple
36
+
37
+ import numpy as np
38
+ from datasets import load_dataset
39
+ from tqdm import tqdm
40
+
41
+ import flax
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import optax
45
+ from flax import jax_utils, traverse_util
46
+ from flax.training import train_state
47
+ from flax.training.common_utils import get_metrics, onehot, shard
48
+ from huggingface_hub import Repository
49
+ from transformers import (
50
+ CONFIG_MAPPING,
51
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
52
+ AutoConfig,
53
+ AutoTokenizer,
54
+ FlaxAutoModelForMaskedLM,
55
+ HfArgumentParser,
56
+ PreTrainedTokenizerBase,
57
+ TensorType,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ from transformers.file_utils import get_full_repo_name
62
+
63
+
64
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
+
67
+
68
+ @dataclass
69
+ class TrainingArguments:
70
+ output_dir: str = field(
71
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
72
+ )
73
+ overwrite_output_dir: bool = field(
74
+ default=False,
75
+ metadata={
76
+ "help": (
77
+ "Overwrite the content of the output directory. "
78
+ "Use this to continue training if output_dir points to a checkpoint directory."
79
+ )
80
+ },
81
+ )
82
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
83
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
84
+ per_device_train_batch_size: int = field(
85
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
86
+ )
87
+ per_device_eval_batch_size: int = field(
88
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
89
+ )
90
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
91
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
92
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
93
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
94
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
95
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
96
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
97
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
98
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
99
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
100
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
101
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
102
+ push_to_hub: bool = field(
103
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
104
+ )
105
+ hub_model_id: str = field(
106
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
107
+ )
108
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
109
+
110
+ def __post_init__(self):
111
+ if self.output_dir is not None:
112
+ self.output_dir = os.path.expanduser(self.output_dir)
113
+
114
+ def to_dict(self):
115
+ """
116
+ Serializes this instance while replacing `Enum` members by their values (for JSON serialization support). It obfuscates
117
+ the token values by replacing them with placeholders.
118
+ """
119
+ d = asdict(self)
120
+ for k, v in d.items():
121
+ if isinstance(v, Enum):
122
+ d[k] = v.value
123
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
124
+ d[k] = [x.value for x in v]
125
+ if k.endswith("_token"):
126
+ d[k] = f"<{k.upper()}>"
127
+ return d
128
+
129
+
130
+ @dataclass
131
+ class ModelArguments:
132
+ """
133
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
134
+ """
135
+
136
+ model_name_or_path: Optional[str] = field(
137
+ default=None,
138
+ metadata={
139
+ "help": "The model checkpoint for weights initialization."
140
+ "Don't set if you want to train a model from scratch."
141
+ },
142
+ )
143
+ model_type: Optional[str] = field(
144
+ default=None,
145
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
146
+ )
147
+ config_name: Optional[str] = field(
148
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
149
+ )
150
+ tokenizer_name: Optional[str] = field(
151
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
152
+ )
153
+ cache_dir: Optional[str] = field(
154
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
155
+ )
156
+ use_fast_tokenizer: bool = field(
157
+ default=True,
158
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
159
+ )
160
+ dtype: Optional[str] = field(
161
+ default="float32",
162
+ metadata={
163
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
164
+ },
165
+ )
166
+
167
+
168
+ @dataclass
169
+ class DataTrainingArguments:
170
+ """
171
+ Arguments pertaining to what data we are going to input our model for training and eval.
172
+ """
173
+
174
+ dataset_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
176
+ )
177
+ dataset_config_name: Optional[str] = field(
178
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
179
+ )
180
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
181
+ validation_file: Optional[str] = field(
182
+ default=None,
183
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
184
+ )
185
+ train_ref_file: Optional[str] = field(
186
+ default=None,
187
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
188
+ )
189
+ validation_ref_file: Optional[str] = field(
190
+ default=None,
191
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
192
+ )
193
+ overwrite_cache: bool = field(
194
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
195
+ )
196
+ validation_split_percentage: Optional[int] = field(
197
+ default=5,
198
+ metadata={
199
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
200
+ },
201
+ )
202
+ max_seq_length: Optional[int] = field(
203
+ default=None,
204
+ metadata={
205
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
206
+ "than this will be truncated. Default to the max input length of the model."
207
+ },
208
+ )
209
+ preprocessing_num_workers: Optional[int] = field(
210
+ default=None,
211
+ metadata={"help": "The number of processes to use for the preprocessing."},
212
+ )
213
+ mlm_probability: float = field(
214
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
215
+ )
216
+ pad_to_max_length: bool = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "Whether to pad all samples to `max_seq_length`. "
220
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
221
+ },
222
+ )
223
+ line_by_line: bool = field(
224
+ default=False,
225
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
226
+ )
227
+
228
+ def __post_init__(self):
229
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
230
+ raise ValueError("Need either a dataset name or a training/validation file.")
231
+ else:
232
+ if self.train_file is not None:
233
+ extension = self.train_file.split(".")[-1]
234
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
235
+ if self.validation_file is not None:
236
+ extension = self.validation_file.split(".")[-1]
237
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
238
+
239
+
240
+ @flax.struct.dataclass
241
+ class FlaxDataCollatorForLanguageModeling:
242
+ """
243
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
244
+ are not all of the same length.
245
+
246
+ Args:
247
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
248
+ The tokenizer used for encoding the data.
249
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
250
+ The probability with which to (randomly) mask tokens in the input.
251
+
252
+ .. note::
253
+
254
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
255
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
256
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
257
+ argument :obj:`return_special_tokens_mask=True`.
258
+ """
259
+
260
+ tokenizer: PreTrainedTokenizerBase
261
+ mlm_probability: float = 0.15
262
+
263
+ def __post_init__(self):
264
+ if self.tokenizer.mask_token is None:
265
+ raise ValueError(
266
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
267
+ "You should pass `mlm=False` to train on causal language modeling instead."
268
+ )
269
+
270
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
271
+ # Handle dict or lists with proper padding and conversion to tensor.
272
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
273
+
274
+ # If special token mask has been preprocessed, pop it from the dict.
275
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
276
+
277
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
278
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
279
+ )
280
+ return batch
281
+
282
+ def mask_tokens(
283
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
284
+ ) -> Tuple[np.ndarray, np.ndarray]:
285
+ """
286
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
287
+ """
288
+ labels = inputs.copy()
289
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
290
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
291
+ special_tokens_mask = special_tokens_mask.astype("bool")
292
+
293
+ probability_matrix[special_tokens_mask] = 0.0
294
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
295
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
296
+
297
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
298
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
299
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
300
+
301
+ # 10% of the time, we replace masked input tokens with a random word
302
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
303
+ indices_random &= masked_indices & ~indices_replaced
304
+
305
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
306
+ inputs[indices_random] = random_words[indices_random]
307
+
308
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
309
+ return inputs, labels
310
+
311
+
312
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
313
+ num_samples = len(samples_idx)
314
+ samples_to_remove = num_samples % batch_size
315
+
316
+ if samples_to_remove != 0:
317
+ samples_idx = samples_idx[:-samples_to_remove]
318
+ sections_split = num_samples // batch_size
319
+ batch_idx = np.split(samples_idx, sections_split)
320
+ return batch_idx
321
+
322
+
323
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
324
+ summary_writer.scalar("train_time", train_time, step)
325
+
326
+ train_metrics = get_metrics(train_metrics)
327
+ for key, vals in train_metrics.items():
328
+ tag = f"train_{key}"
329
+ for i, val in enumerate(vals):
330
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
331
+
332
+
333
+ def write_eval_metric(summary_writer, eval_metrics, step):
334
+ for metric_name, value in eval_metrics.items():
335
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
336
+
337
+
338
+ def main():
339
+ # See all possible arguments in src/transformers/training_args.py
340
+ # or by passing the --help flag to this script.
341
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
342
+
343
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
344
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
345
+ # If we pass only one argument to the script and it's the path to a json file,
346
+ # let's parse it to get our arguments.
347
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
348
+ else:
349
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
350
+
351
+ if (
352
+ os.path.exists(training_args.output_dir)
353
+ and os.listdir(training_args.output_dir)
354
+ and training_args.do_train
355
+ and not training_args.overwrite_output_dir
356
+ ):
357
+ raise ValueError(
358
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
359
+ "Use --overwrite_output_dir to overcome."
360
+ )
361
+
362
+ # Setup logging
363
+ logging.basicConfig(
364
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
365
+ level=logging.INFO,
366
+ datefmt="[%X]",
367
+ )
368
+
369
+ # Log on each process the small summary:
370
+ logger = logging.getLogger(__name__)
371
+
372
+ # Set the verbosity to info of the Transformers logger (on main process only):
373
+ logger.info(f"Training/evaluation parameters {training_args}")
374
+
375
+ # Set seed before initializing model.
376
+ set_seed(training_args.seed)
377
+
378
+ # Handle the repository creation
379
+ if training_args.push_to_hub:
380
+ if training_args.hub_model_id is None:
381
+ repo_name = get_full_repo_name(
382
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
383
+ )
384
+ else:
385
+ repo_name = training_args.hub_model_id
386
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
387
+
388
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
389
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
390
+ # (the dataset will be downloaded automatically from the datasets Hub).
391
+ #
392
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
393
+ # 'text' is found. You can easily tweak this behavior (see below).
394
+ #
395
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
396
+ # download the dataset.
397
+ if data_args.dataset_name is not None:
398
+ # Downloading and loading a dataset from the hub.
399
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
400
+
401
+ if "validation" not in datasets.keys():
402
+ datasets["validation"] = load_dataset(
403
+ data_args.dataset_name,
404
+ data_args.dataset_config_name,
405
+ split=f"train[:{data_args.validation_split_percentage}%]",
406
+ cache_dir=model_args.cache_dir,
407
+ )
408
+ datasets["train"] = load_dataset(
409
+ data_args.dataset_name,
410
+ data_args.dataset_config_name,
411
+ split=f"train[{data_args.validation_split_percentage}%:]",
412
+ cache_dir=model_args.cache_dir,
413
+ )
414
+ else:
415
+ data_files = {}
416
+ if data_args.train_file is not None:
417
+ data_files["train"] = data_args.train_file
418
+ if data_args.validation_file is not None:
419
+ data_files["validation"] = data_args.validation_file
420
+ extension = data_args.train_file.split(".")[-1]
421
+ if extension == "txt":
422
+ extension = "text"
423
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
424
+
425
+ if "validation" not in datasets.keys():
426
+ datasets["validation"] = load_dataset(
427
+ extension,
428
+ data_files=data_files,
429
+ split=f"train[:{data_args.validation_split_percentage}%]",
430
+ cache_dir=model_args.cache_dir,
431
+ )
432
+ datasets["train"] = load_dataset(
433
+ extension,
434
+ data_files=data_files,
435
+ split=f"train[{data_args.validation_split_percentage}%:]",
436
+ cache_dir=model_args.cache_dir,
437
+ )
438
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
439
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
440
+
441
+ # Load pretrained model and tokenizer
442
+
443
+ # Distributed training:
444
+ # The .from_pretrained methods guarantee that only one local process can concurrently
445
+ # download model & vocab.
446
+ if model_args.config_name:
447
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
448
+ elif model_args.model_name_or_path:
449
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
450
+ else:
451
+ config = CONFIG_MAPPING[model_args.model_type]()
452
+ logger.warning("You are instantiating a new config instance from scratch.")
453
+
454
+ if model_args.tokenizer_name:
455
+ tokenizer = AutoTokenizer.from_pretrained(
456
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
457
+ )
458
+ elif model_args.model_name_or_path:
459
+ tokenizer = AutoTokenizer.from_pretrained(
460
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
461
+ )
462
+ else:
463
+ raise ValueError(
464
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
465
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
466
+ )
467
+
468
+ # Preprocessing the datasets.
469
+ # First we tokenize all the texts.
470
+ if training_args.do_train:
471
+ column_names = datasets["train"].column_names
472
+ else:
473
+ column_names = datasets["validation"].column_names
474
+ text_column_name = "text" if "text" in column_names else column_names[0]
475
+
476
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
477
+
478
+ if data_args.line_by_line:
479
+ # When using line_by_line, we just tokenize each nonempty line.
480
+ padding = "max_length" if data_args.pad_to_max_length else False
481
+
482
+ def tokenize_function(examples):
483
+ # Remove empty lines
484
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
485
+ return tokenizer(
486
+ examples,
487
+ return_special_tokens_mask=True,
488
+ padding=padding,
489
+ truncation=True,
490
+ max_length=max_seq_length,
491
+ )
492
+
493
+ tokenized_datasets = datasets.map(
494
+ tokenize_function,
495
+ input_columns=[text_column_name],
496
+ batched=True,
497
+ num_proc=data_args.preprocessing_num_workers,
498
+ remove_columns=column_names,
499
+ load_from_cache_file=not data_args.overwrite_cache,
500
+ )
501
+
502
+ else:
503
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
504
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
505
+ # efficient when it receives the `special_tokens_mask`.
506
+ def tokenize_function(examples):
507
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
508
+
509
+ tokenized_datasets = datasets.map(
510
+ tokenize_function,
511
+ batched=True,
512
+ num_proc=data_args.preprocessing_num_workers,
513
+ remove_columns=column_names,
514
+ load_from_cache_file=not data_args.overwrite_cache,
515
+ )
516
+
517
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
518
+ # max_seq_length.
519
+ def group_texts(examples):
520
+ # Concatenate all texts.
521
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
522
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
523
+ # We drop the small remainder; we could pad instead of dropping if the model supported it. You can
524
+ # customize this part to your needs.
525
+ if total_length >= max_seq_length:
526
+ total_length = (total_length // max_seq_length) * max_seq_length
527
+ # Split by chunks of max_len.
528
+ result = {
529
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
530
+ for k, t in concatenated_examples.items()
531
+ }
532
+ return result
533
+
534
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
535
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
536
+ # might be slower to preprocess.
537
+ #
538
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
539
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
540
+ tokenized_datasets = tokenized_datasets.map(
541
+ group_texts,
542
+ batched=True,
543
+ num_proc=data_args.preprocessing_num_workers,
544
+ load_from_cache_file=not data_args.overwrite_cache,
545
+ )
546
+
547
+ # Enable tensorboard only on the master node
548
+ has_tensorboard = is_tensorboard_available()
549
+ if has_tensorboard and jax.process_index() == 0:
550
+ try:
551
+ # Enable Weights & Biases
552
+ import wandb
553
+ wandb.init(
554
+ entity='versae',
555
+ project='roberta-base-ncc',
556
+ sync_tensorboard=False,
557
+ )
558
+ wandb.config.update(training_args)
559
+ wandb.config.update(model_args)
560
+ wandb.config.update(data_args)
561
+
562
+ from flax.metrics.tensorboard import SummaryWriter
563
+
564
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
565
+ except ImportError as ie:
566
+ has_tensorboard = False
567
+ logger.warning(
568
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
569
+ )
570
+ else:
571
+ logger.warning(
572
+ "Unable to display metrics through TensorBoard because the package is not installed: "
573
+ "Please run pip install tensorboard to enable."
574
+ )
575
+
576
+ # Data collator
577
+ # This one will take care of randomly masking the tokens.
578
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
579
+
580
+ # Initialize our training
581
+ rng = jax.random.PRNGKey(training_args.seed)
582
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
583
+
584
+ if model_args.model_name_or_path:
585
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
586
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
587
+ )
588
+ else:
589
+ model = FlaxAutoModelForMaskedLM.from_config(
590
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
591
+ )
592
+
593
+ # Store some constant
594
+ num_epochs = int(training_args.num_train_epochs)
595
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
596
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
597
+
598
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
599
+
600
+ # Create learning rate schedule
601
+ warmup_fn = optax.linear_schedule(
602
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
603
+ )
604
+ decay_fn = optax.linear_schedule(
605
+ init_value=training_args.learning_rate,
606
+ end_value=0,
607
+ transition_steps=num_train_steps - training_args.warmup_steps,
608
+ )
609
+ linear_decay_lr_schedule_fn = optax.join_schedules(
610
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
611
+ )
612
+
613
+ # We use Optax's "masking" functionality to not apply weight decay
614
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
615
+ # mask boolean with the same structure as the parameters.
616
+ # The mask is True for parameters that should be decayed.
617
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
618
+ # For other models, one should correct the layer norm parameter naming
619
+ # accordingly.
620
+ def decay_mask_fn(params):
621
+ flat_params = traverse_util.flatten_dict(params)
622
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
623
+ return traverse_util.unflatten_dict(flat_mask)
624
+
625
+ # create adam optimizer
626
+ if training_args.adafactor:
627
+ # We use the default parameters here to initialize adafactor,
628
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
629
+ optimizer = optax.adafactor(
630
+ learning_rate=linear_decay_lr_schedule_fn,
631
+ )
632
+ else:
633
+ optimizer = optax.adamw(
634
+ learning_rate=linear_decay_lr_schedule_fn,
635
+ b1=training_args.adam_beta1,
636
+ b2=training_args.adam_beta2,
637
+ eps=training_args.adam_epsilon,
638
+ weight_decay=training_args.weight_decay,
639
+ mask=decay_mask_fn,
640
+ )
641
+
642
+ # Setup train state
643
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
644
+
645
+ # Define gradient update step fn
646
+ def train_step(state, batch, dropout_rng):
647
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
648
+
649
+ def loss_fn(params):
650
+ labels = batch.pop("labels")
651
+
652
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
653
+
654
+ # compute loss, ignore padded input tokens
655
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
656
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
657
+
658
+ # take average
659
+ loss = loss.sum() / label_mask.sum()
660
+
661
+ return loss
662
+
663
+ grad_fn = jax.value_and_grad(loss_fn)
664
+ loss, grad = grad_fn(state.params)
665
+ grad = jax.lax.pmean(grad, "batch")
666
+ new_state = state.apply_gradients(grads=grad)
667
+
668
+ metrics = jax.lax.pmean(
669
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
670
+ )
671
+
672
+ return new_state, metrics, new_dropout_rng
673
+
674
+ # Create parallel version of the train step
675
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
676
+
677
+ # Define eval fn
678
+ def eval_step(params, batch):
679
+ labels = batch.pop("labels")
680
+
681
+ logits = model(**batch, params=params, train=False)[0]
682
+
683
+ # compute loss, ignore padded input tokens
684
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
685
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
686
+
687
+ # compute accuracy
688
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
689
+
690
+ # summarize metrics
691
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
692
+ metrics = jax.lax.psum(metrics, axis_name="batch")
693
+
694
+ return metrics
695
+
696
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
697
+
698
+ # Replicate the train state on each device
699
+ state = jax_utils.replicate(state)
700
+
701
+ train_time = 0
702
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
703
+ for epoch in epochs:
704
+ # ======================== Training ================================
705
+ train_start = time.time()
706
+ train_metrics = []
707
+
708
+ # Create sampling rng
709
+ rng, input_rng = jax.random.split(rng)
710
+
711
+ # Generate an epoch by shuffling sampling indices from the train dataset
712
+ num_train_samples = len(tokenized_datasets["train"])
713
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
714
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
715
+
716
+ # Gather the indexes for creating the batch and do a training step
717
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
718
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
719
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
720
+
721
+ # Model forward
722
+ model_inputs = shard(model_inputs.data)
723
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
724
+ train_metrics.append(train_metric)
725
+
726
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
727
+
728
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
729
+ # Save metrics
730
+ train_metric = jax_utils.unreplicate(train_metric)
731
+ train_time += time.time() - train_start
732
+ if has_tensorboard and jax.process_index() == 0:
733
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
734
+
735
+ epochs.write(
736
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
737
+ )
738
+
739
+ train_metrics = []
740
+
741
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
742
+ # ======================== Evaluating ==============================
743
+ num_eval_samples = len(tokenized_datasets["validation"])
744
+ eval_samples_idx = jnp.arange(num_eval_samples)
745
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
746
+
747
+ eval_metrics = []
748
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
749
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
750
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
751
+
752
+ # Model forward
753
+ model_inputs = shard(model_inputs.data)
754
+ metrics = p_eval_step(state.params, model_inputs)
755
+ eval_metrics.append(metrics)
756
+
757
+ # normalize eval metrics
758
+ eval_metrics = get_metrics(eval_metrics)
759
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
760
+ eval_normalizer = eval_metrics.pop("normalizer")
761
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
762
+
763
+ # Update progress bar
764
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
765
+
766
+ # Save metrics
767
+ if has_tensorboard and jax.process_index() == 0:
768
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
769
+
770
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
771
+ # save checkpoint after each epoch and push checkpoint to the hub
772
+ if jax.process_index() == 0:
773
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
774
+ model.save_pretrained(training_args.output_dir, params=params)
775
+ tokenizer.save_pretrained(training_args.output_dir)
776
+ if training_args.push_to_hub:
777
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
778
+
779
+ # Eval after training
780
+ if training_args.do_eval:
781
+ num_eval_samples = len(tokenized_datasets["validation"])
782
+ eval_samples_idx = jnp.arange(num_eval_samples)
783
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
784
+
785
+ eval_metrics = []
786
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
787
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
788
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
789
+
790
+ # Model forward
791
+ model_inputs = shard(model_inputs.data)
792
+ metrics = p_eval_step(state.params, model_inputs)
793
+ eval_metrics.append(metrics)
794
+
795
+ # normalize eval metrics
796
+ eval_metrics = get_metrics(eval_metrics)
797
+ eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
798
+ eval_normalizer = eval_metrics.pop("normalizer")
799
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
800
+
801
+ try:
802
+ perplexity = math.exp(eval_metrics["loss"])
803
+ except OverflowError:
804
+ perplexity = float("inf")
805
+ eval_metrics["perplexity"] = perplexity
806
+
807
+ if jax.process_index() == 0:
808
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
809
+ path = os.path.join(training_args.output_dir, "eval_results.json")
810
+ with open(path, "w") as f:
811
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
812
+
813
+
814
+ if __name__ == "__main__":
815
+ main()
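
A note on the learning-rate schedule assembled above: the script joins a linear warmup with a linear decay via optax.join_schedules. Below is a minimal, self-contained sketch of that composition, using the peak learning rate (6e-4) and warmup length (1,000 steps) from the run configuration that follows; the total number of training steps is an illustrative placeholder, since the real value is derived from the dataset size at runtime.

import optax

# Hyperparameters mirroring the run configuration; num_train_steps is illustrative only.
learning_rate = 6e-4
warmup_steps = 1_000
num_train_steps = 250_000

# Linear ramp from 0 to the peak learning rate over `warmup_steps` steps.
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
# Linear decay from the peak back to 0 over the remaining steps.
decay_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - warmup_steps
)
# Switch from warmup to decay at step `warmup_steps`.
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

# The schedule peaks at `warmup_steps`, then decays linearly to zero.
for step in (0, 500, 1_000, 125_500, 249_000):
    print(step, float(schedule_fn(step)))
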
wandb/run-20220119_161158-274aad95/files/config.yaml ADDED
@@ -0,0 +1,147 @@
+wandb_version: 1
+
+_wandb:
+  desc: null
+  value:
+    cli_version: 0.12.9
+    code_path: code/run_mlm_flax.py
+    framework: huggingface
+    huggingface_version: 4.16.0.dev0
+    is_jupyter_run: false
+    is_kaggle_kernel: false
+    python_version: 3.8.10
+    start_time: 1642608719
+    t:
+      1:
+      - 2
+      - 3
+      - 11
+      - 12
+      4: 3.8.10
+      5: 0.12.9
+      6: 4.16.0.dev0
+      8:
+      - 5
+adafactor:
+  desc: null
+  value: false
+adam_beta1:
+  desc: null
+  value: 0.9
+adam_beta2:
+  desc: null
+  value: 0.98
+adam_epsilon:
+  desc: null
+  value: 1.0e-06
+cache_dir:
+  desc: null
+  value: null
+config_name:
+  desc: null
+  value: ./
+dataset_config_name:
+  desc: null
+  value: null
+dataset_name:
+  desc: null
+  value: NbAiLab/NCC
+do_eval:
+  desc: null
+  value: true
+do_train:
+  desc: null
+  value: true
+dtype:
+  desc: null
+  value: bfloat16
+eval_steps:
+  desc: null
+  value: 1000
+hub_model_id:
+  desc: null
+  value: null
+hub_token:
+  desc: null
+  value: null
+learning_rate:
+  desc: null
+  value: 0.0006
+line_by_line:
+  desc: null
+  value: false
+logging_steps:
+  desc: null
+  value: 1000
+max_seq_length:
+  desc: null
+  value: 512
+mlm_probability:
+  desc: null
+  value: 0.15
+model_name_or_path:
+  desc: null
+  value: ./
+model_type:
+  desc: null
+  value: roberta
+num_train_epochs:
+  desc: null
+  value: 3.0
+output_dir:
+  desc: null
+  value: ./
+overwrite_cache:
+  desc: null
+  value: false
+overwrite_output_dir:
+  desc: null
+  value: true
+pad_to_max_length:
+  desc: null
+  value: true
+per_device_eval_batch_size:
+  desc: null
+  value: 46
+per_device_train_batch_size:
+  desc: null
+  value: 46
+preprocessing_num_workers:
+  desc: null
+  value: null
+push_to_hub:
+  desc: null
+  value: true
+save_steps:
+  desc: null
+  value: 1000
+seed:
+  desc: null
+  value: 42
+tokenizer_name:
+  desc: null
+  value: ./
+train_file:
+  desc: null
+  value: null
+train_ref_file:
+  desc: null
+  value: null
+use_fast_tokenizer:
+  desc: null
+  value: true
+validation_file:
+  desc: null
+  value: null
+validation_ref_file:
+  desc: null
+  value: null
+validation_split_percentage:
+  desc: null
+  value: 5
+warmup_steps:
+  desc: null
+  value: 1000
+weight_decay:
+  desc: null
+  value: 0.01
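
For orientation, the batch-size settings above imply the following effective sizes per optimizer step. The device count here is an assumption (8 local accelerator cores, e.g. a TPU v3-8); the script itself multiplies the per-device batch size by jax.device_count() at runtime, so the real global batch size depends on the hardware used.

# Back-of-the-envelope check of the effective batch size implied by the config above.
per_device_train_batch_size = 46
max_seq_length = 512
assumed_device_count = 8  # assumption, not read from the config

global_batch_size = per_device_train_batch_size * assumed_device_count
tokens_per_step = global_batch_size * max_seq_length

print(f"global batch size: {global_batch_size}")        # 368
print(f"tokens per optimizer step: {tokens_per_step}")  # 188416, since pad_to_max_length is true
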