aapot committed
Commit 9da3171
1 Parent(s): b152c5e

Saving weights and logs of step 10000

config.json CHANGED
@@ -19,7 +19,7 @@
   "num_hidden_layers": 24,
   "pad_token_id": 1,
   "position_embedding_type": "absolute",
-  "transformers_version": "4.13.0.dev0",
+  "transformers_version": "4.11.0",
   "type_vocab_size": 1,
   "use_cache": true,
   "vocab_size": 50265
events.out.tfevents.1637651508.t1v-n-8eba1090-w-0.74811.0.v2 → events.out.tfevents.1637788246.t1v-n-8eba1090-w-0.278309.0.v2 RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a698de76a6eef179b50ae1f446a42233b23305d204aea21414c9c719c958894a
-size 8912195
+oid sha256:cedc456912f39cab6a93851a70164212aeff38666e4e7d06802401d8ff4983c9
+size 1470757
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3da37fdf3d6ea94d5fcc73090e445b143f824569cf20c9d2a5779a5566dd3c7d
-size 1421662309
+oid sha256:f89a2f1cb697fef6bf98fda870fd214efe0fb3874f01fdc75e5beaed3bef05d0
+size 711588089
run_mlm_flax.py CHANGED
@@ -16,7 +16,6 @@
 """
 Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
 text file or a dataset.
-
 Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
 https://huggingface.co/models?filter=masked-lm
 """
@@ -25,15 +24,12 @@ import os
 import sys
 import time
 from dataclasses import dataclass, field
-
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-
 import numpy as np
 from datasets import load_dataset, load_from_disk
 from tqdm import tqdm
-
 import flax
 import jax
 import jax.numpy as jnp
@@ -41,7 +37,6 @@ import optax
 from flax import jax_utils, traverse_util
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
-from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -55,19 +50,13 @@ from transformers import (
     is_tensorboard_available,
     set_seed,
 )
-from transformers.file_utils import get_full_repo_name
-
-
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
-
-
 @dataclass
 class ModelArguments:
     """
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
     """
-
     model_name_or_path: Optional[str] = field(
         default=None,
         metadata={
@@ -98,14 +87,11 @@ class ModelArguments:
             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
         },
     )
-
-
 @dataclass
 class DataTrainingArguments:
     """
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
-
     dataset_name: Optional[str] = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
@@ -168,7 +154,6 @@ class DataTrainingArguments:
         default=False,
         metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
     )
-
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.dataset_filepath is None and self.validation_file is None:
             raise ValueError("Need either a dataset name or a training/validation file.")
@@ -179,50 +164,39 @@ class DataTrainingArguments:
         if self.validation_file is not None:
             extension = self.validation_file.split(".")[-1]
             assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
-
-
 @flax.struct.dataclass
 class FlaxDataCollatorForLanguageModeling:
     """
     Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
     are not all of the same length.
-
     Args:
         tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input.
-
    .. note::
-
        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`.
    """
-
    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.15
-
    def __post_init__(self):
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )
-
    def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
        # Handle dict or lists with proper padding and conversion to tensor.
        batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
-
        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
-
        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
        return batch
-
    def mask_tokens(
        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray]:
@@ -233,57 +207,41 @@ class FlaxDataCollatorForLanguageModeling:
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = np.full(labels.shape, self.mlm_probability)
        special_tokens_mask = special_tokens_mask.astype("bool")
-
        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
        labels[~masked_indices] = -100  # We only compute loss on masked tokens
-
        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
-
        # 10% of the time, we replace masked input tokens with random word
        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
        indices_random &= masked_indices & ~indices_replaced
-
        random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]
-
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
-
-
 def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
     num_samples = len(samples_idx)
     samples_to_remove = num_samples % batch_size
-
     if samples_to_remove != 0:
         samples_idx = samples_idx[:-samples_to_remove]
     sections_split = num_samples // batch_size
     batch_idx = np.split(samples_idx, sections_split)
     return batch_idx
-
-
 def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)
-
     train_metrics = get_metrics(train_metrics)
     for key, vals in train_metrics.items():
         tag = f"train_{key}"
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
-
-
 def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
-
-
 if __name__ == "__main__":
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
-
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
@@ -291,7 +249,6 @@ if __name__ == "__main__":
         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
-
     if (
         os.path.exists(training_args.output_dir)
         and os.listdir(training_args.output_dir)
@@ -302,33 +259,18 @@ if __name__ == "__main__":
             f"Output directory ({training_args.output_dir}) already exists and is not empty."
             "Use --overwrite_output_dir to overcome."
         )
-
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         level="NOTSET",
         datefmt="[%X]",
     )
-
     # Log on each process the small summary:
     logger = logging.getLogger(__name__)
-
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")
-
     # Set seed before initializing model.
     set_seed(training_args.seed)
-
-    # Handle the repository creation
-    if training_args.push_to_hub:
-        if training_args.hub_model_id is None:
-            repo_name = get_full_repo_name(
-                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
-            )
-        else:
-            repo_name = training_args.hub_model_id
-        repo = Repository(training_args.output_dir, clone_from=repo_name)
-
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
@@ -341,7 +283,6 @@ if __name__ == "__main__":
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
-
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 data_args.dataset_name,
@@ -355,7 +296,6 @@ if __name__ == "__main__":
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
             )
-
     elif data_args.dataset_filepath is not None:
         # Loading a dataset from local file.
         datasets = load_from_disk(data_args.dataset_filepath)
@@ -363,7 +303,6 @@ if __name__ == "__main__":
         datasets = datasets.train_test_split(test_size=data_args.validation_split_percentage/100)
         datasets["validation"] = datasets["test"]
         del datasets["test"]
-
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -374,7 +313,6 @@ if __name__ == "__main__":
         if extension == "txt":
             extension = "text"
         datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
-
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 extension,
@@ -390,9 +328,7 @@ if __name__ == "__main__":
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
-
     # Load pretrained model and tokenizer
-
     # Distributed training:
     # The .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
@@ -403,7 +339,6 @@ if __name__ == "__main__":
     else:
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
-
     if model_args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
@@ -417,7 +352,6 @@ if __name__ == "__main__":
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
-
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
@@ -425,13 +359,10 @@ if __name__ == "__main__":
     else:
         column_names = datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
-
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
-
     if data_args.line_by_line:
         # When using line_by_line, we just tokenize each nonempty line.
         padding = "max_length" if data_args.pad_to_max_length else False
-
         def tokenize_function(examples):
             # Remove empty lines
             examples = [line for line in examples if len(line) > 0 and not line.isspace()]
@@ -442,7 +373,6 @@ if __name__ == "__main__":
                 truncation=True,
                 max_length=max_seq_length,
             )
-
         if data_args.tokenized_dataset_filepath is not None:
             # Loading a tokenized dataset from local file.
             tokenized_datasets = load_from_disk(data_args.tokenized_dataset_filepath)
@@ -455,19 +385,16 @@ if __name__ == "__main__":
                 remove_columns=column_names,
                 load_from_cache_file=not data_args.overwrite_cache,
             )
-
     else:
         # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
         # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
         # efficient when it receives the `special_tokens_mask`.
         def tokenize_function(examples):
             return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
-
         if data_args.tokenized_dataset_filepath is not None:
             # Loading a tokenized dataset from local file.
             tokenized_datasets = load_from_disk(data_args.tokenized_dataset_filepath)
         else:
-
             tokenized_datasets = datasets.map(
                 tokenize_function,
                 batched=True,
@@ -475,7 +402,6 @@ if __name__ == "__main__":
                 remove_columns=column_names,
                 load_from_cache_file=not data_args.overwrite_cache,
             )
-
     # Main data processing function that will concatenate all texts from our dataset and generate chunks of
     # max_seq_length.
     def group_texts(examples):
@@ -492,7 +418,6 @@ if __name__ == "__main__":
             for k, t in concatenated_examples.items()
         }
         return result
-
     # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
     # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
     # might be slower to preprocess.
@@ -505,18 +430,23 @@ if __name__ == "__main__":
         num_proc=data_args.preprocessing_num_workers,
         load_from_cache_file=not data_args.overwrite_cache,
     )
-
+
     # save the tokenized dataset for future runs
     if data_args.save_tokenized_dataset_filepath is not None:
+        if data_args.dataset_filepath is not None:
+            try:
+                os.system(f"sudo rm {data_args.dataset_filepath}/train/cache*")
+                os.system(f"sudo rm {data_args.dataset_filepath}/validation/cache*")
+                os.system(f"sudo rm {data_args.dataset_filepath}/train/tmp*")
+                os.system(f"sudo rm {data_args.dataset_filepath}/validation/tmp*")
+            except:
+                pass
        tokenized_datasets.save_to_disk(data_args.save_tokenized_dataset_filepath)
-
-
    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter
-
            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
@@ -528,15 +458,12 @@ if __name__ == "__main__":
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )
-
    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
-
    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
-
    if model_args.model_name_or_path:
        model = FlaxAutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
@@ -545,14 +472,11 @@ if __name__ == "__main__":
        model = FlaxAutoModelForMaskedLM.from_config(
            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )
-
    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
-
    num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
-
    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
@@ -565,7 +489,6 @@ if __name__ == "__main__":
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )
-
    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
@@ -577,7 +500,6 @@ if __name__ == "__main__":
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)
-
    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
@@ -594,153 +516,121 @@ if __name__ == "__main__":
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )
-
    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
-
    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
-
        def loss_fn(params):
            labels = batch.pop("labels")
-
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-
            # compute loss, ignore padded input tokens
            label_mask = jnp.where(labels > 0, 1.0, 0.0)
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
-
            # take average
            loss = loss.sum() / label_mask.sum()
-
            return loss
-
        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
-
        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )
-
        return new_state, metrics, new_dropout_rng
-
    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-
    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
-
        logits = model(**batch, params=params, train=False)[0]
-
        # compute loss, ignore padded input tokens
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
-
        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
-
        # summarize metrics
        metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
        metrics = jax.lax.psum(metrics, axis_name="batch")
-
        return metrics
-
    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
-
    # Replicate the train state on each device
    state = jax_utils.replicate(state)
-
    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []
-
        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
-
        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
-
        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)
-
            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)
-
            cur_step = epoch * (num_train_samples // train_batch_size) + step
-
            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
-
                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )
-
                train_metrics = []
-
            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                num_eval_samples = len(tokenized_datasets["validation"])
                eval_samples_idx = jnp.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
-
                eval_metrics = []
                for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
                    samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
                    model_inputs = data_collator(samples, pad_to_multiple_of=16)
-
                    # Model forward
                    model_inputs = shard(model_inputs.data)
                    metrics = p_eval_step(state.params, model_inputs)
                    eval_metrics.append(metrics)
-
                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
                eval_normalizer = eval_metrics.pop("normalizer")
                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
-
                # Update progress bar
                epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
-
                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)
-
            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-                    model.save_pretrained(training_args.output_dir, params=params)
-                    tokenizer.save_pretrained(training_args.output_dir)
-                    if training_args.push_to_hub:
-                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
-
+                    model.save_pretrained(
+                        training_args.output_dir,
+                        params=params,
+                        push_to_hub=training_args.push_to_hub,
+                        commit_message=f"Saving weights and logs of step {cur_step}",
+                    )
+
        # save also at the end of epoch
        try:
            if jax.process_index() == 0:
                params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-                model.save_pretrained(training_args.output_dir, params=params)
-                tokenizer.save_pretrained(training_args.output_dir)
-                if training_args.push_to_hub:
-                    repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
+                model.save_pretrained(
+                    training_args.output_dir,
+                    params=params,
+                    push_to_hub=training_args.push_to_hub,
+                    commit_message=f"Saving weights and logs of epoch {epoch}",
+                )
        except:
            # push to hub fails the whole script if nothing new to commit
-            pass
-
+            pass
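For reference, here is a minimal, hedged sketch (not part of this commit) of the checkpoint-saving call the updated run_mlm_flax.py now relies on instead of the removed `huggingface_hub.Repository` workflow. The model id, output directory, and the `push` toggle below are illustrative placeholders.

```python
# Sketch of the save/push pattern used above: save_pretrained() writes
# flax_model.msgpack + config.json locally and can create the Hub commit itself.
# "roberta-base" and "./checkpoint-10000" are placeholder names, not from this repo.
from transformers import FlaxAutoModelForMaskedLM

model = FlaxAutoModelForMaskedLM.from_pretrained("roberta-base")

push = False  # set True after `huggingface-cli login` to also commit and push to the Hub
model.save_pretrained(
    "./checkpoint-10000",
    params=model.params,                 # host copy of the parameters to serialize
    push_to_hub=push,                    # replaces the explicit repo.push_to_hub(...) call
    commit_message="Saving weights and logs of step 10000",
)
```

With `push_to_hub=True`, `save_pretrained` handles the target repo and commit internally, which is why the explicit `Repository` setup and the `get_full_repo_name` import could be dropped from the script.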
start_train.sh CHANGED
@@ -17,7 +17,7 @@ python3 run_mlm_flax.py \
     --adam_beta2="0.98" \
     --adam_epsilon="1e-6" \
     --learning_rate="2e-4" \
-    --warmup_steps="25000" \
+    --warmup_steps="1500" \
     --overwrite_output_dir \
     --num_train_epochs="2" \
     --save_strategy="steps" \
@@ -27,5 +27,4 @@ python3 run_mlm_flax.py \
     --logging_steps="1000" \
     --dtype="bfloat16" \
     --push_to_hub \
-    --hub_model_id="Finnish-NLP/roberta-large-finnish-v2" \
-    --adafactor
+    --hub_model_id="Finnish-NLP/roberta-large-finnish-v2"
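For context on the `--warmup_steps` change, a small self-contained sketch of the linear warmup plus decay schedule that run_mlm_flax.py builds with Optax. The peak rate matches `--learning_rate="2e-4"` above; the total step count is a placeholder, and the decay leg is reconstructed from the upstream Flax MLM example since its definition falls outside the hunks shown.

```python
# Illustrative sketch of the script's learning-rate schedule (numbers are examples).
import optax

learning_rate = 2e-4        # --learning_rate from start_train.sh
warmup_steps = 1500         # --warmup_steps after this commit (was 25000)
num_train_steps = 100_000   # placeholder; the real value depends on the dataset size

warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)
# Assumed decay leg (as in the upstream example): linear decay to 0 over the remaining steps.
decay_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - warmup_steps
)
linear_decay_lr_schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)

# Inspect the schedule at a few steps.
for step in (0, 750, 1500, 50_000, 100_000):
    print(step, float(linear_decay_lr_schedule_fn(step)))
```

Shortening warmup from 25000 to 1500 steps only moves the boundary between the two segments; the peak learning rate itself is unchanged.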