sakares committed on
Commit
51a46c2
1 Parent(s): 1176e2d

adjust running script

Browse files
Files changed (2) hide show
  1. run.sh +2 -3
  2. run_mlm_flax.py +33 -28
run.sh CHANGED
@@ -9,11 +9,10 @@ python3 run_mlm_flax.py \
9
  --dataset_config_name="unshuffled_deduplicated_th" \
10
  --max_seq_length="128" \
11
  --preprocessing_num_workers="64" \
12
- --per_device_train_batch_size="64" \
13
- --per_device_eval_batch_size="64" \
14
  --learning_rate="2e-4" \
15
  --warmup_steps="1000" \
16
  --overwrite_output_dir \
17
  --num_train_epochs="8" \
18
- --dtype="bfloat16" \
19
  --push_to_hub
 
9
  --dataset_config_name="unshuffled_deduplicated_th" \
10
  --max_seq_length="128" \
11
  --preprocessing_num_workers="64" \
12
+ --per_device_train_batch_size="32" \
13
+ --per_device_eval_batch_size="32" \
14
  --learning_rate="2e-4" \
15
  --warmup_steps="1000" \
16
  --overwrite_output_dir \
17
  --num_train_epochs="8" \
 
18
  --push_to_hub
run_mlm_flax.py CHANGED
@@ -56,22 +56,6 @@ from transformers import (
56
  )
57
 
58
 
59
- # Cache the result
60
- has_tensorboard = is_tensorboard_available()
61
- if has_tensorboard:
62
- try:
63
- from flax.metrics.tensorboard import SummaryWriter
64
- except ImportError as ie:
65
- has_tensorboard = False
66
- print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
67
-
68
- else:
69
- print(
70
- "Unable to display metrics through TensorBoard because the package is not installed: "
71
- "Please run pip install tensorboard to enable."
72
- )
73
-
74
-
75
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
76
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
77
 
@@ -269,7 +253,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
269
  return batch_idx
270
 
271
 
272
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
273
  summary_writer.scalar("train_time", train_time, step)
274
 
275
  train_metrics = get_metrics(train_metrics)
@@ -278,6 +262,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
278
  for i, val in enumerate(vals):
279
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
280
 
 
 
281
  for metric_name, value in eval_metrics.items():
282
  summary_writer.scalar(f"eval_{metric_name}", value, step)
283
 
@@ -315,10 +301,6 @@ if __name__ == "__main__":
315
 
316
  # Log on each process the small summary:
317
  logger = logging.getLogger(__name__)
318
- logger.warning(
319
- f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
320
- + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
321
- )
322
 
323
  # Set the verbosity to info of the Transformers logger (on main process only):
324
  logger.info(f"Training/evaluation parameters {training_args}")
@@ -471,8 +453,22 @@ if __name__ == "__main__":
471
  )
472
 
473
  # Enable tensorboard only on the master node
 
474
  if has_tensorboard and jax.process_index() == 0:
475
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
 
 
 
 
 
 
 
 
 
 
 
 
 
476
 
477
  # Data collator
478
  # This one will take care of randomly masking the tokens.
@@ -601,7 +597,7 @@ if __name__ == "__main__":
601
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
602
 
603
  # Gather the indexes for creating the batch and do a training step
604
- for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
605
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
606
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
607
 
@@ -610,11 +606,20 @@ if __name__ == "__main__":
610
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
611
  train_metrics.append(train_metric)
612
 
613
- train_time += time.time() - train_start
614
 
615
- epochs.write(
616
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
617
- )
 
 
 
 
 
 
 
 
 
618
 
619
  # ======================== Evaluating ==============================
620
  num_eval_samples = len(tokenized_datasets["validation"])
@@ -645,7 +650,7 @@ if __name__ == "__main__":
645
  # Save metrics
646
  if has_tensorboard and jax.process_index() == 0:
647
  cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
648
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
649
 
650
  # save checkpoint after each epoch and push checkpoint to the hub
651
  if jax.process_index() == 0:
 
56
  )
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
60
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
61
 
 
253
  return batch_idx
254
 
255
 
256
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
257
  summary_writer.scalar("train_time", train_time, step)
258
 
259
  train_metrics = get_metrics(train_metrics)
 
262
  for i, val in enumerate(vals):
263
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
264
 
265
+
266
+ def write_eval_metric(summary_writer, eval_metrics, step):
267
  for metric_name, value in eval_metrics.items():
268
  summary_writer.scalar(f"eval_{metric_name}", value, step)
269
 
 
301
 
302
  # Log on each process the small summary:
303
  logger = logging.getLogger(__name__)
 
 
 
 
304
 
305
  # Set the verbosity to info of the Transformers logger (on main process only):
306
  logger.info(f"Training/evaluation parameters {training_args}")
 
453
  )
454
 
455
  # Enable tensorboard only on the master node
456
+ has_tensorboard = is_tensorboard_available()
457
  if has_tensorboard and jax.process_index() == 0:
458
+ try:
459
+ from flax.metrics.tensorboard import SummaryWriter
460
+
461
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
462
+ except ImportError as ie:
463
+ has_tensorboard = False
464
+ logger.warning(
465
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
466
+ )
467
+ else:
468
+ logger.warning(
469
+ "Unable to display metrics through TensorBoard because the package is not installed: "
470
+ "Please run pip install tensorboard to enable."
471
+ )
472
 
473
  # Data collator
474
  # This one will take care of randomly masking the tokens.
 
597
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
598
 
599
  # Gather the indexes for creating the batch and do a training step
600
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
601
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
602
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
603
 
 
606
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
607
  train_metrics.append(train_metric)
608
 
609
+ cur_step = epoch * num_train_samples + step
610
 
611
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
612
+ # Save metrics
613
+ train_metric = jax_utils.unreplicate(train_metric)
614
+ train_time += time.time() - train_start
615
+ if has_tensorboard and jax.process_index() == 0:
616
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
617
+
618
+ epochs.write(
619
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
620
+ )
621
+
622
+ train_metrics = []
623
 
624
  # ======================== Evaluating ==============================
625
  num_eval_samples = len(tokenized_datasets["validation"])
 
650
  # Save metrics
651
  if has_tensorboard and jax.process_index() == 0:
652
  cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
653
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
654
 
655
  # save checkpoint after each epoch and push checkpoint to the hub
656
  if jax.process_index() == 0: