amank commited on
Commit
139e10d
1 Parent(s): 39dde54

Updated code to work with streaming version

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
.vscode/launch.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "Python: StreamingFile",
9
+ "type": "python",
10
+ "request": "launch",
11
+ "program": "${file}",
12
+ "args": [
13
+ "--output_dir","roberta_mc4_sentence_piece",
14
+ "--model_type","roberta",
15
+ "--config_name","roberta_mc4_sentence_piece",
16
+ "--tokenizer_name","roberta_mc4_sentence_piece",
17
+ "--dataset_name","mc4",
18
+ "--dataset_config_name","hi",
19
+ "--max_seq_length","256",
20
+ "--per_device_train_batch_size","128",
21
+ "--per_device_eval_batch_size","128",
22
+ "--learning_rate","3e-4",
23
+ "--warmup_steps","1000",
24
+ "--overwrite_output_dir",
25
+ "--adam_beta1","0.9",
26
+ "--adam_beta2","0.98",
27
+ "--num_train_steps","10000",
28
+ "--num_eval_samples","5000",
29
+ "--logging_steps","250",
30
+ "--eval_steps","1000"
31
+ ],
32
+ "console": "integratedTerminal"
33
+ },
34
+ {
35
+ "name": "Python: Current File",
36
+ "type": "python",
37
+ "request": "launch",
38
+ "program": "${file}",
39
+ "console": "integratedTerminal"
40
+ }
41
+ ]
42
+ }
events.out.tfevents.1625416432.t1v-n-9df4ce0e-w-0.447041.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:10ad97376c517e5e19a89e76161722dd57fe0da7a5aa8bb2b16eb3234749e607
3
- size 40
 
 
 
 
events.out.tfevents.1625418057.t1v-n-9df4ce0e-w-0.452509.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e6c4826820ff2695d2564e7e0bd9b2d1b181883ba8fd1e89493d0379a4425454
3
- size 41580506
 
 
 
 
flax_model.msgpack DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:904705cb48bfd59a5d3aed1f30746844238567349459d7e3e09aca20dbfb1e35
3
- size 498796983
 
 
 
 
pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4887e2e3b00807be6e343d2aa0abd4d3fd8eb844d48d55ef39efa22b0e1ba0b1
3
- size 498877970
 
 
 
 
roberta_mc4_sentence_piece/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.6.1",
23
+ "type_vocab_size": 1,
24
+ "use_cache": true,
25
+ "vocab_size": 50265
26
+ }
roberta_mc4_sentence_piece/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
run_mlm_flax.py CHANGED
@@ -40,6 +40,7 @@ import jax.numpy as jnp
40
  import optax
41
  from flax import jax_utils, traverse_util
42
  from flax.training import train_state
 
43
  from flax.training.common_utils import get_metrics, onehot, shard
44
  from transformers import (
45
  CONFIG_MAPPING,
@@ -56,22 +57,6 @@ from transformers import (
56
  )
57
 
58
 
59
- # Cache the result
60
- has_tensorboard = is_tensorboard_available()
61
- if has_tensorboard:
62
- try:
63
- from flax.metrics.tensorboard import SummaryWriter
64
- except ImportError as ie:
65
- has_tensorboard = False
66
- print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
67
-
68
- else:
69
- print(
70
- "Unable to display metrics through TensorBoard because the package is not installed: "
71
- "Please run pip install tensorboard to enable."
72
- )
73
-
74
-
75
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
76
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
77
 
@@ -186,6 +171,7 @@ class DataTrainingArguments:
186
  assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
187
 
188
 
 
189
  @flax.struct.dataclass
190
  class FlaxDataCollatorForLanguageModeling:
191
  """
@@ -269,7 +255,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
269
  return batch_idx
270
 
271
 
272
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
273
  summary_writer.scalar("train_time", train_time, step)
274
 
275
  train_metrics = get_metrics(train_metrics)
@@ -278,6 +264,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
278
  for i, val in enumerate(vals):
279
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
280
 
 
 
281
  for metric_name, value in eval_metrics.items():
282
  summary_writer.scalar(f"eval_{metric_name}", value, step)
283
 
@@ -315,16 +303,11 @@ if __name__ == "__main__":
315
 
316
  # Log on each process the small summary:
317
  logger = logging.getLogger(__name__)
318
- logger.warning(
319
- f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
320
- + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
321
- )
322
 
323
  # Set the verbosity to info of the Transformers logger (on main process only):
324
  logger.info(f"Training/evaluation parameters {training_args}")
325
 
326
  # Set seed before initializing model.
327
- training_args.seed = 42
328
  set_seed(training_args.seed)
329
 
330
  # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
@@ -363,6 +346,19 @@ if __name__ == "__main__":
363
  if extension == "txt":
364
  extension = "text"
365
  datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
367
  # https://huggingface.co/docs/datasets/loading_datasets.html.
368
 
@@ -450,7 +446,8 @@ if __name__ == "__main__":
450
  total_length = len(concatenated_examples[list(examples.keys())[0]])
451
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
452
  # customize this part to your needs.
453
- total_length = (total_length // max_seq_length) * max_seq_length
 
454
  # Split by chunks of max_len.
455
  result = {
456
  k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
@@ -472,8 +469,22 @@ if __name__ == "__main__":
472
  )
473
 
474
  # Enable tensorboard only on the master node
 
475
  if has_tensorboard and jax.process_index() == 0:
476
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
 
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
  # Data collator
479
  # This one will take care of randomly masking the tokens.
@@ -483,7 +494,14 @@ if __name__ == "__main__":
483
  rng = jax.random.PRNGKey(training_args.seed)
484
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
485
 
486
- model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
 
 
 
 
 
 
 
487
 
488
  # Store some constant
489
  num_epochs = int(training_args.num_train_epochs)
@@ -518,17 +536,24 @@ if __name__ == "__main__":
518
  return traverse_util.unflatten_dict(flat_mask)
519
 
520
  # create adam optimizer
521
- adamw = optax.adamw(
522
- learning_rate=linear_decay_lr_schedule_fn,
523
- b1=training_args.adam_beta1,
524
- b2=training_args.adam_beta2,
525
- eps=1e-8,
526
- weight_decay=training_args.weight_decay,
527
- mask=decay_mask_fn,
528
- )
 
 
 
 
 
 
 
529
 
530
  # Setup train state
531
- state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
532
 
533
  # Define gradient update step fn
534
  def train_step(state, batch, dropout_rng):
@@ -588,7 +613,6 @@ if __name__ == "__main__":
588
 
589
  train_time = 0
590
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
591
- save_checkpoint=True
592
  for epoch in epochs:
593
  # ======================== Training ================================
594
  train_start = time.time()
@@ -603,7 +627,7 @@ if __name__ == "__main__":
603
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
604
 
605
  # Gather the indexes for creating the batch and do a training step
606
- for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
607
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
608
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
609
 
@@ -611,58 +635,63 @@ if __name__ == "__main__":
611
  model_inputs = shard(model_inputs.data)
612
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
613
  train_metrics.append(train_metric)
614
- if save_checkpoint and (train_metric['loss'] < 1.).all():
615
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
616
- model.save_pretrained(
617
- '/home/khandelia1000/checkpoints/',
618
- params=params,
619
- push_to_hub=False
620
- )
621
- save_checkpoint = False
622
 
623
- train_time += time.time() - train_start
624
 
625
- epochs.write(
626
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
627
- )
628
-
629
- # ======================== Evaluating ==============================
630
- num_eval_samples = len(tokenized_datasets["validation"])
631
- eval_samples_idx = jnp.arange(num_eval_samples)
632
- eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
633
 
634
- eval_metrics = []
635
- for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
636
- samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
637
- model_inputs = data_collator(samples, pad_to_multiple_of=16)
638
-
639
- # Model forward
640
- model_inputs = shard(model_inputs.data)
641
- metrics = p_eval_step(state.params, model_inputs)
642
- eval_metrics.append(metrics)
643
-
644
- # normalize eval metrics
645
- eval_metrics = get_metrics(eval_metrics)
646
- eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
647
- eval_normalizer = eval_metrics.pop("normalizer")
648
- eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
649
-
650
- # Update progress bar
651
- epochs.desc = (
652
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
653
- )
654
 
655
- # Save metrics
656
- if has_tensorboard and jax.process_index() == 0:
657
- cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
658
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
659
-
660
- # save checkpoint after each epoch and push checkpoint to the hub
661
- if jax.process_index() == 0:
662
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
663
- model.save_pretrained(
664
- training_args.output_dir,
665
- params=params,
666
- push_to_hub=training_args.push_to_hub,
667
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
668
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  import optax
41
  from flax import jax_utils, traverse_util
42
  from flax.training import train_state
43
+ from flax.serialization import from_bytes, to_bytes
44
  from flax.training.common_utils import get_metrics, onehot, shard
45
  from transformers import (
46
  CONFIG_MAPPING,
 
57
  )
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
61
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
62
 
 
171
  assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
172
 
173
 
174
+
175
  @flax.struct.dataclass
176
  class FlaxDataCollatorForLanguageModeling:
177
  """
 
255
  return batch_idx
256
 
257
 
258
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
259
  summary_writer.scalar("train_time", train_time, step)
260
 
261
  train_metrics = get_metrics(train_metrics)
 
264
  for i, val in enumerate(vals):
265
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
266
 
267
+
268
+ def write_eval_metric(summary_writer, eval_metrics, step):
269
  for metric_name, value in eval_metrics.items():
270
  summary_writer.scalar(f"eval_{metric_name}", value, step)
271
 
 
303
 
304
  # Log on each process the small summary:
305
  logger = logging.getLogger(__name__)
 
 
 
 
306
 
307
  # Set the verbosity to info of the Transformers logger (on main process only):
308
  logger.info(f"Training/evaluation parameters {training_args}")
309
 
310
  # Set seed before initializing model.
 
311
  set_seed(training_args.seed)
312
 
313
  # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
 
346
  if extension == "txt":
347
  extension = "text"
348
  datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
349
+ if data_args.validation_file is None:
350
+ datasets["validation"] = load_dataset(
351
+ extension, data_files=data_files,
352
+ split=f"train[:{data_args.validation_split_percentage}%]",
353
+ cache_dir=model_args.cache_dir,
354
+ )
355
+ datasets["train"] = load_dataset(
356
+ extension, data_files=data_files,
357
+ split=f"train[{data_args.validation_split_percentage}%:]",
358
+ cache_dir=model_args.cache_dir,
359
+ )
360
+ print(datasets)
361
+
362
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
363
  # https://huggingface.co/docs/datasets/loading_datasets.html.
364
 
 
446
  total_length = len(concatenated_examples[list(examples.keys())[0]])
447
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
448
  # customize this part to your needs.
449
+ if total_length >= max_seq_length:
450
+ total_length = (total_length // max_seq_length) * max_seq_length
451
  # Split by chunks of max_len.
452
  result = {
453
  k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
 
469
  )
470
 
471
  # Enable tensorboard only on the master node
472
+ has_tensorboard = is_tensorboard_available()
473
  if has_tensorboard and jax.process_index() == 0:
474
+ try:
475
+ from flax.metrics.tensorboard import SummaryWriter
476
+
477
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
478
+ except ImportError as ie:
479
+ has_tensorboard = False
480
+ logger.warning(
481
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
482
+ )
483
+ else:
484
+ logger.warning(
485
+ "Unable to display metrics through TensorBoard because the package is not installed: "
486
+ "Please run pip install tensorboard to enable."
487
+ )
488
 
489
  # Data collator
490
  # This one will take care of randomly masking the tokens.
 
494
  rng = jax.random.PRNGKey(training_args.seed)
495
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
496
 
497
+ if model_args.model_name_or_path:
498
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
499
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
500
+ )
501
+ else:
502
+ model = FlaxAutoModelForMaskedLM.from_config(
503
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
504
+ )
505
 
506
  # Store some constant
507
  num_epochs = int(training_args.num_train_epochs)
 
536
  return traverse_util.unflatten_dict(flat_mask)
537
 
538
  # create adam optimizer
539
+ if training_args.adafactor:
540
+ # We use the default parameters here to initialize adafactor,
541
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
542
+ optimizer = optax.adafactor(
543
+ learning_rate=linear_decay_lr_schedule_fn,
544
+ )
545
+ else:
546
+ optimizer = optax.adamw(
547
+ learning_rate=linear_decay_lr_schedule_fn,
548
+ b1=training_args.adam_beta1,
549
+ b2=training_args.adam_beta2,
550
+ eps=training_args.adam_epsilon,
551
+ weight_decay=training_args.weight_decay,
552
+ mask=decay_mask_fn,
553
+ )
554
 
555
  # Setup train state
556
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
557
 
558
  # Define gradient update step fn
559
  def train_step(state, batch, dropout_rng):
 
613
 
614
  train_time = 0
615
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
 
616
  for epoch in epochs:
617
  # ======================== Training ================================
618
  train_start = time.time()
 
627
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
628
 
629
  # Gather the indexes for creating the batch and do a training step
630
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
631
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
632
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
633
 
 
635
  model_inputs = shard(model_inputs.data)
636
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
637
  train_metrics.append(train_metric)
 
 
 
 
 
 
 
 
638
 
639
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
640
 
641
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
642
+ # Save metrics
643
+ train_metric = jax_utils.unreplicate(train_metric)
644
+ train_time += time.time() - train_start
645
+ if has_tensorboard and jax.process_index() == 0:
646
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
 
 
647
 
648
+ epochs.write(
649
+ f"Step... ({cur_step} | Train Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
650
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651
 
652
+ train_metrics = []
653
+
654
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
655
+ # ======================== Evaluating ==============================
656
+ num_eval_samples = len(tokenized_datasets["validation"])
657
+ eval_samples_idx = jnp.arange(num_eval_samples)
658
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
659
+
660
+ eval_metrics = []
661
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
662
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
663
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
664
+
665
+ # Model forward
666
+ model_inputs = shard(model_inputs.data)
667
+ metrics = p_eval_step(state.params, model_inputs)
668
+ eval_metrics.append(metrics)
669
+
670
+ # normalize eval metrics
671
+ eval_metrics = get_metrics(eval_metrics)
672
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
673
+ eval_normalizer = eval_metrics.pop("normalizer")
674
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
675
+
676
+ # Update progress bar
677
+ epochs.desc = f"Step... ({cur_step} | Val Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
678
+
679
+ # Save metrics
680
+ if has_tensorboard and jax.process_index() == 0:
681
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
682
+
683
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
684
+ # save checkpoint after each epoch and push checkpoint to the hub
685
+ if jax.process_index() == 0:
686
+ step_output_dir = f"checkpoint_{cur_step}"
687
+ os.mkdir(step_output_dir)
688
+ print(f"Saving weights, optimizer state and logs of step {cur_step} at {step_output_dir}")
689
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
690
+ model.save_pretrained(
691
+ step_output_dir,
692
+ params=params,
693
+ push_to_hub=training_args.push_to_hub,
694
+ commit_message=f"Saving weights and logs of step {cur_step}",
695
+ )
696
+ with open("opt_state.msgpack", "wb") as f:
697
+ f.write(to_bytes(state.opt_state))
run_mlm_flax_old.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=masked-lm
22
+ """
23
+ import logging
24
+ import os
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+
29
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Tuple
32
+
33
+ import numpy as np
34
+ from datasets import load_dataset
35
+ from tqdm import tqdm
36
+
37
+ import flax
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import optax
41
+ from flax import jax_utils, traverse_util
42
+ from flax.training import train_state
43
+ from flax.training.common_utils import get_metrics, onehot, shard
44
+ from transformers import (
45
+ CONFIG_MAPPING,
46
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
47
+ AutoConfig,
48
+ AutoTokenizer,
49
+ FlaxAutoModelForMaskedLM,
50
+ HfArgumentParser,
51
+ PreTrainedTokenizerBase,
52
+ TensorType,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+
58
+
59
+ # Cache the result
60
+ has_tensorboard = is_tensorboard_available()
61
+ if has_tensorboard:
62
+ try:
63
+ from flax.metrics.tensorboard import SummaryWriter
64
+ except ImportError as ie:
65
+ has_tensorboard = False
66
+ print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
67
+
68
+ else:
69
+ print(
70
+ "Unable to display metrics through TensorBoard because the package is not installed: "
71
+ "Please run pip install tensorboard to enable."
72
+ )
73
+
74
+
75
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
76
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
77
+
78
+
79
+ @dataclass
80
+ class ModelArguments:
81
+ """
82
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
83
+ """
84
+
85
+ model_name_or_path: Optional[str] = field(
86
+ default=None,
87
+ metadata={
88
+ "help": "The model checkpoint for weights initialization."
89
+ "Don't set if you want to train a model from scratch."
90
+ },
91
+ )
92
+ model_type: Optional[str] = field(
93
+ default=None,
94
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
95
+ )
96
+ config_name: Optional[str] = field(
97
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
98
+ )
99
+ tokenizer_name: Optional[str] = field(
100
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
101
+ )
102
+ cache_dir: Optional[str] = field(
103
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
104
+ )
105
+ use_fast_tokenizer: bool = field(
106
+ default=True,
107
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
108
+ )
109
+ dtype: Optional[str] = field(
110
+ default="float32",
111
+ metadata={
112
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
113
+ },
114
+ )
115
+
116
+
117
+ @dataclass
118
+ class DataTrainingArguments:
119
+ """
120
+ Arguments pertaining to what data we are going to input our model for training and eval.
121
+ """
122
+
123
+ dataset_name: Optional[str] = field(
124
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
125
+ )
126
+ dataset_config_name: Optional[str] = field(
127
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
128
+ )
129
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
130
+ validation_file: Optional[str] = field(
131
+ default=None,
132
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
133
+ )
134
+ train_ref_file: Optional[str] = field(
135
+ default=None,
136
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
137
+ )
138
+ validation_ref_file: Optional[str] = field(
139
+ default=None,
140
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
141
+ )
142
+ overwrite_cache: bool = field(
143
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
144
+ )
145
+ validation_split_percentage: Optional[int] = field(
146
+ default=5,
147
+ metadata={
148
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
149
+ },
150
+ )
151
+ max_seq_length: Optional[int] = field(
152
+ default=None,
153
+ metadata={
154
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
155
+ "than this will be truncated. Default to the max input length of the model."
156
+ },
157
+ )
158
+ preprocessing_num_workers: Optional[int] = field(
159
+ default=None,
160
+ metadata={"help": "The number of processes to use for the preprocessing."},
161
+ )
162
+ mlm_probability: float = field(
163
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
164
+ )
165
+ pad_to_max_length: bool = field(
166
+ default=False,
167
+ metadata={
168
+ "help": "Whether to pad all samples to `max_seq_length`. "
169
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
170
+ },
171
+ )
172
+ line_by_line: bool = field(
173
+ default=False,
174
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
175
+ )
176
+
177
+ def __post_init__(self):
178
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
179
+ raise ValueError("Need either a dataset name or a training/validation file.")
180
+ else:
181
+ if self.train_file is not None:
182
+ extension = self.train_file.split(".")[-1]
183
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
184
+ if self.validation_file is not None:
185
+ extension = self.validation_file.split(".")[-1]
186
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
187
+
188
+
189
+ @flax.struct.dataclass
190
+ class FlaxDataCollatorForLanguageModeling:
191
+ """
192
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
193
+ are not all of the same length.
194
+
195
+ Args:
196
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
197
+ The tokenizer used for encoding the data.
198
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
199
+ The probability with which to (randomly) mask tokens in the input.
200
+
201
+ .. note::
202
+
203
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
204
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
205
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
206
+ argument :obj:`return_special_tokens_mask=True`.
207
+ """
208
+
209
+ tokenizer: PreTrainedTokenizerBase
210
+ mlm_probability: float = 0.15
211
+
212
+ def __post_init__(self):
213
+ if self.tokenizer.mask_token is None:
214
+ raise ValueError(
215
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
216
+ "You should pass `mlm=False` to train on causal language modeling instead."
217
+ )
218
+
219
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
220
+ # Handle dict or lists with proper padding and conversion to tensor.
221
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
222
+
223
+ # If special token mask has been preprocessed, pop it from the dict.
224
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
225
+
226
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
227
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
228
+ )
229
+ return batch
230
+
231
+ def mask_tokens(
232
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
233
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
234
+ """
235
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
236
+ """
237
+ labels = inputs.copy()
238
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
239
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
240
+ special_tokens_mask = special_tokens_mask.astype("bool")
241
+
242
+ probability_matrix[special_tokens_mask] = 0.0
243
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
244
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
245
+
246
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
247
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
248
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
249
+
250
+ # 10% of the time, we replace masked input tokens with random word
251
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
252
+ indices_random &= masked_indices & ~indices_replaced
253
+
254
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
255
+ inputs[indices_random] = random_words[indices_random]
256
+
257
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
258
+ return inputs, labels
259
+
260
+
261
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
262
+ num_samples = len(samples_idx)
263
+ samples_to_remove = num_samples % batch_size
264
+
265
+ if samples_to_remove != 0:
266
+ samples_idx = samples_idx[:-samples_to_remove]
267
+ sections_split = num_samples // batch_size
268
+ batch_idx = np.split(samples_idx, sections_split)
269
+ return batch_idx
270
+
271
+
272
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
273
+ summary_writer.scalar("train_time", train_time, step)
274
+
275
+ train_metrics = get_metrics(train_metrics)
276
+ for key, vals in train_metrics.items():
277
+ tag = f"train_{key}"
278
+ for i, val in enumerate(vals):
279
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
280
+
281
+ for metric_name, value in eval_metrics.items():
282
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
283
+
284
+
285
+ if __name__ == "__main__":
286
+ # See all possible arguments in src/transformers/training_args.py
287
+ # or by passing the --help flag to this script.
288
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
289
+
290
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
291
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
292
+ # If we pass only one argument to the script and it's the path to a json file,
293
+ # let's parse it to get our arguments.
294
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
295
+ else:
296
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
297
+
298
+ if (
299
+ os.path.exists(training_args.output_dir)
300
+ and os.listdir(training_args.output_dir)
301
+ and training_args.do_train
302
+ and not training_args.overwrite_output_dir
303
+ ):
304
+ raise ValueError(
305
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
306
+ "Use --overwrite_output_dir to overcome."
307
+ )
308
+
309
+ # Setup logging
310
+ logging.basicConfig(
311
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
312
+ level="NOTSET",
313
+ datefmt="[%X]",
314
+ )
315
+
316
+ # Log on each process the small summary:
317
+ logger = logging.getLogger(__name__)
318
+ logger.warning(
319
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
320
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
321
+ )
322
+
323
+ # Set the verbosity to info of the Transformers logger (on main process only):
324
+ logger.info(f"Training/evaluation parameters {training_args}")
325
+
326
+ # Set seed before initializing model.
327
+ training_args.seed = 42
328
+ set_seed(training_args.seed)
329
+
330
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
331
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
332
+ # (the dataset will be downloaded automatically from the datasets Hub).
333
+ #
334
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
335
+ # 'text' is found. You can easily tweak this behavior (see below).
336
+ #
337
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
338
+ # download the dataset.
339
+ if data_args.dataset_name is not None:
340
+ # Downloading and loading a dataset from the hub.
341
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
342
+
343
+ if "validation" not in datasets.keys():
344
+ datasets["validation"] = load_dataset(
345
+ data_args.dataset_name,
346
+ data_args.dataset_config_name,
347
+ split=f"train[:{data_args.validation_split_percentage}%]",
348
+ cache_dir=model_args.cache_dir,
349
+ )
350
+ datasets["train"] = load_dataset(
351
+ data_args.dataset_name,
352
+ data_args.dataset_config_name,
353
+ split=f"train[{data_args.validation_split_percentage}%:]",
354
+ cache_dir=model_args.cache_dir,
355
+ )
356
+ else:
357
+ data_files = {}
358
+ if data_args.train_file is not None:
359
+ data_files["train"] = data_args.train_file
360
+ if data_args.validation_file is not None:
361
+ data_files["validation"] = data_args.validation_file
362
+ extension = data_args.train_file.split(".")[-1]
363
+ if extension == "txt":
364
+ extension = "text"
365
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
366
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
367
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
368
+
369
+ # Load pretrained model and tokenizer
370
+
371
+ # Distributed training:
372
+ # The .from_pretrained methods guarantee that only one local process can concurrently
373
+ # download model & vocab.
374
+ if model_args.config_name:
375
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
376
+ elif model_args.model_name_or_path:
377
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
378
+ else:
379
+ config = CONFIG_MAPPING[model_args.model_type]()
380
+ logger.warning("You are instantiating a new config instance from scratch.")
381
+
382
+ if model_args.tokenizer_name:
383
+ tokenizer = AutoTokenizer.from_pretrained(
384
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
385
+ )
386
+ elif model_args.model_name_or_path:
387
+ tokenizer = AutoTokenizer.from_pretrained(
388
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
389
+ )
390
+ else:
391
+ raise ValueError(
392
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
393
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
394
+ )
395
+
396
+ # Preprocessing the datasets.
397
+ # First we tokenize all the texts.
398
+ if training_args.do_train:
399
+ column_names = datasets["train"].column_names
400
+ else:
401
+ column_names = datasets["validation"].column_names
402
+ text_column_name = "text" if "text" in column_names else column_names[0]
403
+
404
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
405
+
406
+ if data_args.line_by_line:
407
+ # When using line_by_line, we just tokenize each nonempty line.
408
+ padding = "max_length" if data_args.pad_to_max_length else False
409
+
410
+ def tokenize_function(examples):
411
+ # Remove empty lines
412
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
413
+ return tokenizer(
414
+ examples,
415
+ return_special_tokens_mask=True,
416
+ padding=padding,
417
+ truncation=True,
418
+ max_length=max_seq_length,
419
+ )
420
+
421
+ tokenized_datasets = datasets.map(
422
+ tokenize_function,
423
+ input_columns=[text_column_name],
424
+ batched=True,
425
+ num_proc=data_args.preprocessing_num_workers,
426
+ remove_columns=column_names,
427
+ load_from_cache_file=not data_args.overwrite_cache,
428
+ )
429
+
430
+ else:
431
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
432
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
433
+ # efficient when it receives the `special_tokens_mask`.
434
+ def tokenize_function(examples):
435
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
436
+
437
+ tokenized_datasets = datasets.map(
438
+ tokenize_function,
439
+ batched=True,
440
+ num_proc=data_args.preprocessing_num_workers,
441
+ remove_columns=column_names,
442
+ load_from_cache_file=not data_args.overwrite_cache,
443
+ )
444
+
445
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
446
+ # max_seq_length.
447
+ def group_texts(examples):
448
+ # Concatenate all texts.
449
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
450
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
451
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
452
+ # customize this part to your needs.
453
+ total_length = (total_length // max_seq_length) * max_seq_length
454
+ # Split by chunks of max_len.
455
+ result = {
456
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
457
+ for k, t in concatenated_examples.items()
458
+ }
459
+ return result
460
+
461
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
462
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
463
+ # might be slower to preprocess.
464
+ #
465
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
466
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
467
+ tokenized_datasets = tokenized_datasets.map(
468
+ group_texts,
469
+ batched=True,
470
+ num_proc=data_args.preprocessing_num_workers,
471
+ load_from_cache_file=not data_args.overwrite_cache,
472
+ )
473
+
474
+ # Enable tensorboard only on the master node
475
+ if has_tensorboard and jax.process_index() == 0:
476
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
477
+
478
+ # Data collator
479
+ # This one will take care of randomly masking the tokens.
480
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
481
+
482
+ # Initialize our training
483
+ rng = jax.random.PRNGKey(training_args.seed)
484
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
485
+
486
+ model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
487
+
488
+ # Store some constant
489
+ num_epochs = int(training_args.num_train_epochs)
490
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
491
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
492
+
493
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
494
+
495
+ # Create learning rate schedule
496
+ warmup_fn = optax.linear_schedule(
497
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
498
+ )
499
+ decay_fn = optax.linear_schedule(
500
+ init_value=training_args.learning_rate,
501
+ end_value=0,
502
+ transition_steps=num_train_steps - training_args.warmup_steps,
503
+ )
504
+ linear_decay_lr_schedule_fn = optax.join_schedules(
505
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
506
+ )
507
+
508
+ # We use Optax's "masking" functionality to not apply weight decay
509
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
510
+ # mask boolean with the same structure as the parameters.
511
+ # The mask is True for parameters that should be decayed.
512
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
513
+ # For other models, one should correct the layer norm parameter naming
514
+ # accordingly.
515
+ def decay_mask_fn(params):
516
+ flat_params = traverse_util.flatten_dict(params)
517
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
518
+ return traverse_util.unflatten_dict(flat_mask)
519
+
520
+ # create adam optimizer
521
+ adamw = optax.adamw(
522
+ learning_rate=linear_decay_lr_schedule_fn,
523
+ b1=training_args.adam_beta1,
524
+ b2=training_args.adam_beta2,
525
+ eps=1e-8,
526
+ weight_decay=training_args.weight_decay,
527
+ mask=decay_mask_fn,
528
+ )
529
+
530
+ # Setup train state
531
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
532
+
533
+ # Define gradient update step fn
534
+ def train_step(state, batch, dropout_rng):
535
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
536
+
537
+ def loss_fn(params):
538
+ labels = batch.pop("labels")
539
+
540
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
541
+
542
+ # compute loss, ignore padded input tokens
543
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
544
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
545
+
546
+ # take average
547
+ loss = loss.sum() / label_mask.sum()
548
+
549
+ return loss
550
+
551
+ grad_fn = jax.value_and_grad(loss_fn)
552
+ loss, grad = grad_fn(state.params)
553
+ grad = jax.lax.pmean(grad, "batch")
554
+ new_state = state.apply_gradients(grads=grad)
555
+
556
+ metrics = jax.lax.pmean(
557
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
558
+ )
559
+
560
+ return new_state, metrics, new_dropout_rng
561
+
562
+ # Create parallel version of the train step
563
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
564
+
565
+ # Define eval fn
566
+ def eval_step(params, batch):
567
+ labels = batch.pop("labels")
568
+
569
+ logits = model(**batch, params=params, train=False)[0]
570
+
571
+ # compute loss, ignore padded input tokens
572
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
573
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
574
+
575
+ # compute accuracy
576
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
577
+
578
+ # summarize metrics
579
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
580
+ metrics = jax.lax.psum(metrics, axis_name="batch")
581
+
582
+ return metrics
583
+
584
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
585
+
586
+ # Replicate the train state on each device
587
+ state = jax_utils.replicate(state)
588
+
589
+ train_time = 0
590
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
591
+ save_checkpoint=True
592
+ for epoch in epochs:
593
+ # ======================== Training ================================
594
+ train_start = time.time()
595
+ train_metrics = []
596
+
597
+ # Create sampling rng
598
+ rng, input_rng = jax.random.split(rng)
599
+
600
+ # Generate an epoch by shuffling sampling indices from the train dataset
601
+ num_train_samples = len(tokenized_datasets["train"])
602
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
603
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
604
+
605
+ # Gather the indexes for creating the batch and do a training step
606
+ for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
607
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
608
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
609
+
610
+ # Model forward
611
+ model_inputs = shard(model_inputs.data)
612
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
613
+ train_metrics.append(train_metric)
614
+ if save_checkpoint and (train_metric['loss'] < 1.).all():
615
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
616
+ model.save_pretrained(
617
+ '/home/khandelia1000/checkpoints/',
618
+ params=params,
619
+ push_to_hub=False
620
+ )
621
+ save_checkpoint = False
622
+
623
+ train_time += time.time() - train_start
624
+
625
+ epochs.write(
626
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
627
+ )
628
+
629
+ # ======================== Evaluating ==============================
630
+ num_eval_samples = len(tokenized_datasets["validation"])
631
+ eval_samples_idx = jnp.arange(num_eval_samples)
632
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
633
+
634
+ eval_metrics = []
635
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
636
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
637
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
638
+
639
+ # Model forward
640
+ model_inputs = shard(model_inputs.data)
641
+ metrics = p_eval_step(state.params, model_inputs)
642
+ eval_metrics.append(metrics)
643
+
644
+ # normalize eval metrics
645
+ eval_metrics = get_metrics(eval_metrics)
646
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
647
+ eval_normalizer = eval_metrics.pop("normalizer")
648
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
649
+
650
+ # Update progress bar
651
+ epochs.desc = (
652
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
653
+ )
654
+
655
+ # Save metrics
656
+ if has_tensorboard and jax.process_index() == 0:
657
+ cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
658
+ write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
659
+
660
+ # save checkpoint after each epoch and push checkpoint to the hub
661
+ if jax.process_index() == 0:
662
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
663
+ model.save_pretrained(
664
+ training_args.output_dir,
665
+ params=params,
666
+ push_to_hub=training_args.push_to_hub,
667
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
668
+ )
run_mlm_flax_stream.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=masked-lm
22
+ """
23
+ import logging
24
+ import os
25
+ import sys
26
+ import time
27
+ from collections import defaultdict
28
+ from dataclasses import dataclass, field
29
+
30
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
31
+ from pathlib import Path
32
+ from typing import Dict, List, Optional, Tuple
33
+
34
+ from utils import keep_devnagri
35
+
36
+ import datasets
37
+ import numpy as np
38
+ from datasets import load_dataset
39
+ from tqdm import tqdm
40
+
41
+ import flax
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import optax
45
+ from flax import jax_utils, traverse_util
46
+ from flax.training import train_state
47
+ from flax.training.common_utils import get_metrics, onehot, shard
48
+ from transformers import (
49
+ CONFIG_MAPPING,
50
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
51
+ AutoConfig,
52
+ AutoTokenizer,
53
+ FlaxAutoModelForMaskedLM,
54
+ HfArgumentParser,
55
+ PreTrainedTokenizerBase,
56
+ TensorType,
57
+ TrainingArguments,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+
62
+
63
+ # if datasets.__version__ <= "1.8.0":
64
+ # raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
65
+
66
+
67
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
68
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
69
+
70
+
71
+ @dataclass
72
+ class ModelArguments:
73
+ """
74
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
75
+ """
76
+
77
+ model_name_or_path: Optional[str] = field(
78
+ default=None,
79
+ metadata={
80
+ "help": "The model checkpoint for weights initialization."
81
+ "Don't set if you want to train a model from scratch."
82
+ },
83
+ )
84
+ model_type: Optional[str] = field(
85
+ default=None,
86
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
87
+ )
88
+ config_name: Optional[str] = field(
89
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
90
+ )
91
+ tokenizer_name: Optional[str] = field(
92
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
93
+ )
94
+ cache_dir: Optional[str] = field(
95
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
96
+ )
97
+ use_fast_tokenizer: bool = field(
98
+ default=True,
99
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
100
+ )
101
+ dtype: Optional[str] = field(
102
+ default="float32",
103
+ metadata={
104
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
105
+ },
106
+ )
107
+
108
+
109
+ @dataclass
110
+ class DataTrainingArguments:
111
+ """
112
+ Arguments pertaining to what data we are going to input our model for training and eval.
113
+ """
114
+
115
+ dataset_name: Optional[str] = field(
116
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
117
+ )
118
+ dataset_config_name: Optional[str] = field(
119
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
120
+ )
121
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
122
+ validation_file: Optional[str] = field(
123
+ default=None,
124
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
125
+ )
126
+ train_ref_file: Optional[str] = field(
127
+ default=None,
128
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
129
+ )
130
+ validation_ref_file: Optional[str] = field(
131
+ default=None,
132
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
133
+ )
134
+ overwrite_cache: bool = field(
135
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
136
+ )
137
+ validation_split_percentage: Optional[int] = field(
138
+ default=5,
139
+ metadata={
140
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
141
+ },
142
+ )
143
+ max_seq_length: Optional[int] = field(
144
+ default=None,
145
+ metadata={
146
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
147
+ "than this will be truncated. Default to the max input length of the model."
148
+ },
149
+ )
150
+ preprocessing_num_workers: Optional[int] = field(
151
+ default=None,
152
+ metadata={"help": "The number of processes to use for the preprocessing."},
153
+ )
154
+ mlm_probability: float = field(
155
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
156
+ )
157
+ pad_to_max_length: bool = field(
158
+ default=False,
159
+ metadata={
160
+ "help": "Whether to pad all samples to `max_seq_length`. "
161
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
162
+ },
163
+ )
164
+ line_by_line: bool = field(
165
+ default=False,
166
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
167
+ )
168
+ text_column_name: str = field(
169
+ default="text", metadata={"help": "The name of the column to retrieve the training text."}
170
+ )
171
+ shuffle_buffer_size: int = field(
172
+ default=10000, metadata={"help": "The number of examples to pre-load for shuffling."}
173
+ )
174
+ num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
175
+ num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
176
+
177
+ def __post_init__(self):
178
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
179
+ raise ValueError("Need either a dataset name or a training/validation file.")
180
+ else:
181
+ if self.train_file is not None:
182
+ extension = self.train_file.split(".")[-1]
183
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
184
+ if self.validation_file is not None:
185
+ extension = self.validation_file.split(".")[-1]
186
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
187
+
188
+
189
+ @flax.struct.dataclass
190
+ class FlaxDataCollatorForLanguageModeling:
191
+ """
192
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
193
+ are not all of the same length.
194
+
195
+ Args:
196
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
197
+ The tokenizer used for encoding the data.
198
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
199
+ The probability with which to (randomly) mask tokens in the input.
200
+
201
+ .. note::
202
+
203
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
204
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
205
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
206
+ argument :obj:`return_special_tokens_mask=True`.
207
+ """
208
+
209
+ tokenizer: PreTrainedTokenizerBase
210
+ mlm_probability: float = 0.15
211
+
212
+ def __post_init__(self):
213
+ if self.tokenizer.mask_token is None:
214
+ raise ValueError(
215
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
216
+ "You should pass `mlm=False` to train on causal language modeling instead."
217
+ )
218
+
219
+ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
220
+ # Handle dict or lists with proper padding and conversion to tensor.
221
+ batch = self.tokenizer.pad(examples, return_tensors=TensorType.NUMPY)
222
+
223
+ # If special token mask has been preprocessed, pop it from the dict.
224
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
225
+
226
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
227
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
228
+ )
229
+ return batch
230
+
231
+ def mask_tokens(
232
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
233
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
234
+ """
235
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
236
+ """
237
+ labels = inputs.copy()
238
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
239
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
240
+ special_tokens_mask = special_tokens_mask.astype("bool")
241
+
242
+ probability_matrix[special_tokens_mask] = 0.0
243
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
244
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
245
+
246
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
247
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
248
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
249
+
250
+ # 10% of the time, we replace masked input tokens with random word
251
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
252
+ indices_random &= masked_indices & ~indices_replaced
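+ # half of the masked positions not replaced by [MASK] (~10% of all masked positions) get a random token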
253
+
254
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
255
+ inputs[indices_random] = random_words[indices_random]
256
+
257
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
258
+ return inputs, labels
259
+
260
+
261
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
262
+ num_samples = len(samples_idx)
263
+ samples_to_remove = num_samples % batch_size
264
+
265
+ if samples_to_remove != 0:
266
+ samples_idx = samples_idx[:-samples_to_remove]
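+ # drop the trailing samples so the indices split evenly into full batches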
267
+ sections_split = num_samples // batch_size
268
+ batch_idx = np.split(samples_idx, sections_split)
269
+ return batch_idx
270
+
271
+
272
+ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
273
+ """
274
+ The training iterator is advanced so that, after grouping the samples,
275
+ `num_samples` sequences of length `max_seq_length` are returned.
276
+ """
277
+ num_total_tokens = max_seq_length * num_samples
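+ # keep pulling documents from the stream until enough tokens have been collected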
278
+ samples = defaultdict(list)
279
+
280
+ i = 0
281
+ doc_count = 0
282
+ while i < num_total_tokens:
283
+ tokenized_samples = next(train_iterator)
284
+ i += len(tokenized_samples["input_ids"])
285
+ doc_count += 1
286
+
287
+ # concatenate tokenized samples to list
288
+ samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
289
+
290
+
291
+ # Concatenated tokens are split into lists of length `max_seq_length`.
292
+ # Note that the remainder (total length % max_seq_length) is thrown away.
293
+ def group_texts(examples):
294
+ result = {
295
+ k: [t[i : i + max_seq_length] for i in range(0, num_total_tokens, max_seq_length)]
296
+ for k, t in examples.items()
297
+ }
298
+ return result
299
+
300
+ grouped_samples = group_texts(samples)
301
+ return doc_count, grouped_samples
302
+
303
+
304
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
305
+ summary_writer.scalar("train_time", train_time, step)
306
+
307
+ train_metrics = get_metrics(train_metrics)
308
+ for key, vals in train_metrics.items():
309
+ tag = f"train_{key}"
310
+ for i, val in enumerate(vals):
311
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
312
+
313
+
314
+ def write_eval_metric(summary_writer, eval_metrics, step):
315
+ for metric_name, value in eval_metrics.items():
316
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
317
+
318
+
319
+ if __name__ == "__main__":
320
+ # See all possible arguments in src/transformers/training_args.py
321
+ # or by passing the --help flag to this script.
322
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
323
+
324
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
325
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
326
+ # If we pass only one argument to the script and it's the path to a json file,
327
+ # let's parse it to get our arguments.
328
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
329
+ else:
330
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
331
+
332
+ if (
333
+ os.path.exists(training_args.output_dir)
334
+ and os.listdir(training_args.output_dir)
335
+ and training_args.do_train
336
+ and not training_args.overwrite_output_dir
337
+ ):
338
+ raise ValueError(
339
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
340
+ "Use --overwrite_output_dir to overcome."
341
+ )
342
+
343
+ # Setup logging
344
+ logging.basicConfig(
345
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
346
+ level="INFO",
347
+ datefmt="[%X]",
348
+ )
349
+
350
+ # Log on each process the small summary:
351
+ logger = logging.getLogger(__name__)
352
+ logger.warning(
353
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
354
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
355
+ )
356
+
357
+ # Set the verbosity to info of the Transformers logger (on main process only):
358
+ logger.info(f"Training/evaluation parameters {training_args}")
359
+
360
+ # Set seed before initializing model.
361
+ set_seed(training_args.seed)
362
+
363
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
364
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
365
+ # (the dataset will be downloaded automatically from the datasets Hub).
366
+ #
367
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
368
+ # 'text' is found. You can easily tweak this behavior (see below).
369
+ if data_args.dataset_name is not None:
370
+ # Downloading and loading a dataset from the hub.
371
+ dataset = load_dataset(
372
+ data_args.dataset_name,
373
+ data_args.dataset_config_name,
374
+ cache_dir=model_args.cache_dir,
375
+ streaming=True,
376
+ split="train",
377
+ )
378
+
379
+ if model_args.config_name:
380
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
381
+ elif model_args.model_name_or_path:
382
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
383
+ else:
384
+ config = CONFIG_MAPPING[model_args.model_type]()
385
+ logger.warning("You are instantiating a new config instance from scratch.")
386
+
387
+ if model_args.tokenizer_name:
388
+ tokenizer = AutoTokenizer.from_pretrained(
389
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
390
+ )
391
+ elif model_args.model_name_or_path:
392
+ tokenizer = AutoTokenizer.from_pretrained(
393
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
394
+ )
395
+ else:
396
+ raise ValueError(
397
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
398
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
399
+ )
400
+
401
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
402
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
403
+ # efficient when it receives the `special_tokens_mask`.
404
+ def tokenize_function(examples):
405
+ return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
406
+
407
+ cleaned_dataset = dataset.map(
408
+ keep_devnagri,
409
+ batched=False,
410
+ )
411
+ tokenized_datasets = cleaned_dataset.map(
412
+ tokenize_function,
413
+ batched=True,
414
+ )
415
+
416
+ shuffle_seed = training_args.seed
417
+ tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
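+ # streaming datasets are shuffled with a fixed-size buffer; a larger buffer mixes better at the cost of memory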
418
+
419
+ has_tensorboard = is_tensorboard_available()
420
+ if has_tensorboard and jax.process_index() == 0:
421
+ try:
422
+ from flax.metrics.tensorboard import SummaryWriter
423
+ except ImportError as ie:
424
+ has_tensorboard = False
425
+ logger.warning(
426
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
427
+ )
428
+
429
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
430
+
431
+ # Data collator
432
+ # This one will take care of randomly masking the tokens.
433
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
434
+
435
+ # Initialize our training
436
+ rng = jax.random.PRNGKey(training_args.seed)
437
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
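+ # one dropout RNG per local device, consumed and refreshed by the pmapped train step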
438
+
439
+ if model_args.model_name_or_path:
440
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
441
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
442
+ )
443
+ else:
444
+ model = FlaxAutoModelForMaskedLM.from_config(
445
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
446
+ )
447
+ if jax.device_count() < 8:
448
+ input('Number of devices as per jax.device_count() is {}, fewer than the expected 8. Press Enter to continue.'.format(jax.device_count()))
449
+
450
+ # Store some constant
451
+ num_epochs = int(training_args.num_train_epochs)
452
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
453
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
454
+
455
+ # define number steps per stream epoch
456
+ dataset_doc_count = 18507273
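+ # total number of documents in the mC4 Hindi train split (same count as noted in train_tokenizer.py)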
457
+ num_train_steps = ((dataset_doc_count//train_batch_size) + 1) * num_epochs * 2
458
+
459
+ # Create learning rate schedule
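+ # (linear warmup to the peak learning rate over `warmup_steps`, then linear decay to zero)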
460
+ warmup_fn = optax.linear_schedule(
461
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
462
+ )
463
+ decay_fn = optax.linear_schedule(
464
+ init_value=training_args.learning_rate,
465
+ end_value=0,
466
+ transition_steps=num_train_steps - training_args.warmup_steps,
467
+ )
468
+ linear_decay_lr_schedule_fn = optax.join_schedules(
469
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
470
+ )
471
+
472
+ # We use Optax's "masking" functionality to not apply weight decay
473
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
474
+ # mask boolean with the same structure as the parameters.
475
+ # The mask is True for parameters that should be decayed.
476
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
477
+ # For other models, one should correct the layer norm parameter naming
478
+ # accordingly.
479
+ def decay_mask_fn(params):
480
+ flat_params = traverse_util.flatten_dict(params)
481
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
482
+ return traverse_util.unflatten_dict(flat_mask)
483
+
484
+ # create adam optimizer
485
+ adamw = optax.adamw(
486
+ learning_rate=linear_decay_lr_schedule_fn,
487
+ b1=training_args.adam_beta1,
488
+ b2=training_args.adam_beta2,
489
+ eps=training_args.adam_epsilon,
490
+ weight_decay=training_args.weight_decay,
491
+ mask=decay_mask_fn,
492
+ )
493
+
494
+ # Setup train state
495
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
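+ # bundle the parameters, optimizer state and apply function into a single train state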
496
+
497
+ # Define gradient update step fn
498
+ def train_step(state, batch, dropout_rng):
499
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
500
+
501
+ def loss_fn(params):
502
+ labels = batch.pop("labels")
503
+
504
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
505
+
506
+ # compute loss only on masked tokens (non-masked and padded positions are labelled -100 by the collator)
507
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
508
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
509
+
510
+ # take average
511
+ loss = loss.sum() / label_mask.sum()
512
+
513
+ return loss
514
+
515
+ grad_fn = jax.value_and_grad(loss_fn)
516
+ loss, grad = grad_fn(state.params)
517
+ grad = jax.lax.pmean(grad, "batch")
518
+ new_state = state.apply_gradients(grads=grad)
519
+
520
+ metrics = jax.lax.pmean(
521
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
522
+ )
523
+
524
+ return new_state, metrics, new_dropout_rng
525
+
526
+ # Create parallel version of the train step
527
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
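+ # donate_argnums=(0,) lets XLA reuse the buffers of the incoming state for the updated state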
528
+
529
+ # Define eval fn
530
+ def eval_step(params, batch):
531
+ labels = batch.pop("labels")
532
+
533
+ logits = model(**batch, params=params, train=False)[0]
534
+
535
+ # compute loss only on masked tokens (non-masked and padded positions are labelled -100 by the collator)
536
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
537
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
538
+
539
+ # compute accuracy
540
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
541
+
542
+ # summarize metrics
543
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
544
+ metrics = jax.lax.psum(metrics, axis_name="batch")
545
+
546
+ return metrics
547
+
548
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
549
+
550
+ # Replicate the train state on each device
551
+ state = jax_utils.replicate(state)
552
+
553
+ train_time = 0
554
+ train_start = time.time()
555
+ train_metrics = []
556
+ eval_metrics = []
557
+
558
+ training_iter = iter(tokenized_datasets)
559
+
560
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
561
+ doc_count, eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
562
+
563
+ steps = tqdm(range(num_train_steps), desc="Training...", position=0)
564
+ docs_progress_bar = tqdm(range(dataset_doc_count * num_epochs), desc="Docs Processed...", position=0)
565
+ for step in range(num_train_steps):
566
+ # ======================== Training ================================
567
+ try:
568
+ doc_count, samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
569
+
570
+ except StopIteration:
571
+ # Once the end of the dataset stream is reached, the training iterator
572
+ # is reinitialized and reshuffled, and a new eval dataset is randomly chosen.
573
+ shuffle_seed += 1
574
+ tokenized_datasets.set_epoch(shuffle_seed)
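+ # set_epoch reseeds the shuffle buffer so the next pass over the stream uses a different order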
575
+
576
+ training_iter = iter(tokenized_datasets)
577
+
578
+ _, eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
579
+ doc_count, samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
580
+
581
+
582
+ # process input samples
583
+ model_inputs = data_collator(samples)
584
+
585
+ # Model forward
586
+ model_inputs = shard(model_inputs.data)
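+ # reshape the batch so the leading axis maps one shard per local device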
587
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
588
+
589
+ train_metrics.append(train_metric)
590
+
591
+ if step % training_args.logging_steps == 0 and step > 0:
592
+ steps.write(
593
+ f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
594
+ )
595
+ train_time += time.time() - train_start
596
+ if has_tensorboard and jax.process_index() == 0:
597
+ write_train_metric(summary_writer, train_metrics, train_time, step)
598
+ train_metrics = []
599
+
600
+ # ======================== Evaluating ==============================
601
+ if step % training_args.eval_steps == 0 and step > 0:
602
+ eval_samples_idx = jnp.arange(data_args.num_eval_samples)
603
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
604
+
605
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
606
+ # process input samples
607
+ batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
608
+ model_inputs = data_collator(batch_eval_samples)
609
+
610
+ # Model forward
611
+ model_inputs = shard(model_inputs.data)
612
+ metrics = p_eval_step(state.params, model_inputs)
613
+ eval_metrics.append(metrics)
614
+
615
+ # normalize eval metrics
616
+ eval_metrics = get_metrics(eval_metrics)
617
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
618
+ eval_normalizer = eval_metrics.pop("normalizer")
619
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
620
+
621
+ # Update progress bar
622
+ steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
623
+
624
+ if has_tensorboard and jax.process_index() == 0:
625
+ write_eval_metric(summary_writer, eval_metrics, step)
626
+ eval_metrics = []
627
+
628
+ # save a checkpoint after each evaluation and push it to the hub
629
+ if jax.process_index() == 0:
630
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
631
+ model.save_pretrained(
632
+ training_args.output_dir,
633
+ params=params,
634
+ push_to_hub=training_args.push_to_hub,
635
+ commit_message=f"Saving weights and logs of step {step+1}",
636
+ )
637
+
638
+ # update tqdm bar
639
+ docs_progress_bar.update(doc_count)
640
+ steps.update(1)
run_stream.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python3 -c "import jax; print(jax.devices())"
2
+ ./run_mlm_flax_stream.py \
3
+ --output_dir="${MODEL_DIR}" \
4
+ --model_type="roberta" \
5
+ --config_name="${MODEL_DIR}" \
6
+ --tokenizer_name="${MODEL_DIR}" \
7
+ --dataset_name="mc4" \
8
+ --dataset_config_name="hi" \
9
+ --max_seq_length="256" \
10
+ --per_device_train_batch_size="128" \
11
+ --per_device_eval_batch_size="128" \
12
+ --learning_rate="3e-4" \
13
+ --warmup_steps="1000" \
14
+ --overwrite_output_dir \
15
+ --adam_beta1="0.9" \
16
+ --adam_beta2="0.98" \
17
+ --num_train_steps="10000" \
18
+ --num_eval_samples="5000" \
19
+ --logging_steps="250" \
20
+ --eval_steps="1000"
train_tokenizer.py CHANGED
@@ -1,19 +1,31 @@
1
  #!/usr/bin/env python3
2
  from datasets import load_dataset
3
  from datasets import load_from_disk
4
- from tokenizers import ByteLevelBPETokenizer
5
  from tqdm import tqdm
6
- # load dataset
7
- # dataset = load_dataset("oscar", "unshuffled_deduplicated_hi", split="train")
8
 
9
- dataset = load_from_disk("/home/rtx/work/dk/hf/vo")
 
 
 
10
 
11
  # Instantiate tokenizer
12
- tokenizer = ByteLevelBPETokenizer(add_prefix_space=True)
13
 
14
  def batch_iterator(batch_size=100_000):
15
- for i in range(0, len(dataset), batch_size):
16
- yield dataset[i: i + batch_size]["text"]
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Customized training
19
  tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=50, special_tokens=[
@@ -22,8 +34,8 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=
22
  "</s>",
23
  "<unk>",
24
  "<mask>",
25
- ])
26
 
27
  # Save files to disk
28
- tokenizer.save("./tokenizer.json")
29
 
 
1
  #!/usr/bin/env python3
2
  from datasets import load_dataset
3
  from datasets import load_from_disk
4
+ from tokenizers import ByteLevelBPETokenizer, SentencePieceBPETokenizer
5
  from tqdm import tqdm
 
 
6
 
7
+ from utils import keep_devnagri
8
+
9
+ # load dataset
10
+ dataset = load_dataset("mc4", "hi", split="train", streaming=True)
11
 
12
  # Instantiate tokenizer
13
+ tokenizer = SentencePieceBPETokenizer(add_prefix_space=True)
14
 
15
  def batch_iterator(batch_size=100_000):
16
+ # total docs: 1,85,07,273 (18,507,273)
17
+ text_ls = []
18
+
19
+ for example in dataset:
20
+ devnagari_text = keep_devnagri(example)["text"]
21
+ if devnagari_text.strip():
22
+ text_ls.append(devnagari_text)
23
+ if len(text_ls) == batch_size:
24
+ yield text_ls
25
+ text_ls = []
26
+ if len(text_ls) > 0:
27
+ yield text_ls
28
+
29
 
30
  # Customized training
31
  tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=50, special_tokens=[
 
34
  "</s>",
35
  "<unk>",
36
  "<mask>",
37
+ ], )
38
 
39
  # Save files to disk
40
+ tokenizer.save("/home/khandelia1000/tokenizer.json")
41
 
utils.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import regex as re
2
+ import string
3
+
4
+ def keep_devnagri(document: dict):
5
+ """
6
+ Remove all non-Devanagari characters from the text of a dataset example.
7
+ Code adapted from https://huggingface.co/flax-community/roberta-base-mr/blob/64d2c745f264f09c3d5b678a718746b2613887db/mr_clean_text.py
8
+
9
+ @param document: dict Dataset example whose 'text' field is to be cleaned
10
+ @return: dict The same example with its 'text' field cleaned in place
11
+ """
12
+ text = document['text']
13
+ pattern = r'[\p{Devanagari}0-9।\s\.\!]+'
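+ # keep runs of Devanagari characters, ASCII digits, danda (।), whitespace, '.' and '!'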
14
+
15
+ # regex pattern for all puntuation symbols
16
+ punctuation_regex = re.compile("[" + re.escape(string.punctuation) + string.digits + "|" + "]")
17
+
18
+ # keep only the text which is in devnagari script
19
+ cleaned = "".join([tok.group() for tok in re.finditer(pattern, text)])
20
+
21
+ # remove any extra space between words
22
+ cleaned = re.sub(r"[ ]+", " ", cleaned)
23
+
24
+ # identify if the clean text only consists of punctuation
25
+ is_just_punctuation = len(re.sub(punctuation_regex, "", cleaned)) == 0
26
+
27
+ # replace with a single space, since an empty string may cause issues for the tokenizer;
28
+ # this only happens for roughly 5 out of 10,000 docs, so it should not
29
+ # affect the results
30
+ if is_just_punctuation:
31
+ document['text'] = " "
32
+ else:
33
+ document['text'] = cleaned
34
+ return document