prateekagrawal committed on
Commit
b1b3841
1 Parent(s): 37fa322

Saving weights and logs of step 8

Files changed (9)
  1. README.md +0 -0
  2. config.json +0 -0
  3. create_config.py +0 -0
  4. flax_model.msgpack +3 -0
  5. run.sh +5 -2
  6. run.sh.save +16 -0
  7. run_mlm_flax.py +39 -35
  8. tokenizer.json +0 -0
  9. train_tokenizer.py +0 -0
README.md CHANGED
File without changes
config.json CHANGED
File without changes
create_config.py CHANGED
File without changes
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2e26684c7b415b88900d2b10f657004a4262d41aca55f28d52013b051535c43
+ size 498796983
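
Those three lines are a Git LFS pointer: the actual ~499 MB Flax checkpoint lives in LFS storage, not in the Git history itself. Below is a minimal sketch of loading the saved weights, assuming the repository has been cloned with LFS resolved (so flax_model.msgpack is the real file) and that config.json and tokenizer.json sit beside it, as --output_dir="./" in run.sh implies; the RoBERTa model class and the Italian example sentence are assumptions based on --model_type="roberta" and the Italian OSCAR data, not something stated in this commit.

# Sketch only: load the checkpoint this commit saved.
# Assumes a local clone with Git LFS resolved and a RoBERTa-style config.
from transformers import AutoTokenizer, FlaxRobertaForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("./")       # reads tokenizer.json
model = FlaxRobertaForMaskedLM.from_pretrained("./")  # reads config.json + flax_model.msgpack

inputs = tokenizer("La capitale d'Italia è <mask>.", return_tensors="np")
logits = model(**inputs).logits
print(logits.shape)  # (1, sequence_length, vocab_size)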
run.sh CHANGED
@@ -7,10 +7,13 @@
  --dataset_name="oscar" \
  --dataset_config_name="unshuffled_deduplicated_it" \
  --max_seq_length="128" \
- --per_device_train_batch_size="4" \
- --per_device_eval_batch_size="4" \
+ --per_device_train_batch_size="1" \
+ --per_device_eval_batch_size="1" \
  --learning_rate="3e-4" \
  --warmup_steps="1000" \
  --overwrite_output_dir \
  --num_train_epochs="8" \
+ --logging_steps="10" \
+ --save_steps="8" \
+ --eval_steps="15" \
  --push_to_hub
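
The new flags move logging, evaluation, and checkpointing onto a step-based schedule. A rough sanity check of that schedule against the debug-sized dataset introduced in run_mlm_flax.py below; the device count of 8 is an assumption (e.g. a TPU v3-8) and is not stated anywhere in this commit.

# Illustrative arithmetic only; 8 devices is an assumption, not from the commit.
per_device_train_batch_size = 1   # new value in run.sh
num_devices = 8                   # assumed
train_batch_size = per_device_train_batch_size * num_devices    # 8

num_train_samples = 10_000        # from datasets["train"].select(range(10000)) below
steps_per_epoch = num_train_samples // train_batch_size         # 1250

save_steps = 8                    # --save_steps="8"
print(steps_per_epoch, save_steps)
# The first checkpoint would land at global step 8, which matches the
# commit message "Saving weights and logs of step 8".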
run.sh.save ADDED
@@ -0,0 +1,16 @@
+ /usr/bin/env bash
+ ./run_mlm_flax.py \
+ --output_dir="./" \
+ --model_type="roberta" \
+ --config_name="./" \
+ --tokenizer_name="./" \
+ --dataset_name="oscar" \
+ --dataset_config_name="unshuffled_deduplicated_it" \
+ --max_seq_length="128" \
+ --per_device_train_batch_size="4" \
+ --per_device_eval_batch_size="4" \
+ --learning_rate="3e-4" \
+ --warmup_steps="1000" \
+ --overwrite_output_dir \
+ --num_train_epochs="8" \
+ --push_to_hub
run_mlm_flax.py CHANGED
@@ -297,6 +297,10 @@ if __name__ == "__main__":
         if extension == "txt":
             extension = "text"
         datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+
+    datasets["train"] = datasets["train"].select(range(10000))
+    datasets["validation"] = datasets["validation"].select(range(1000))
+
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
     # Load pretrained model and tokenizer
@@ -512,7 +516,7 @@ if __name__ == "__main__":
             model_inputs = shard(model_inputs.data)
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
-            cur_step = epoch * num_train_samples + step
+            cur_step = epoch * (num_train_samples // train_batch_size) + step
             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                 # Save metrics
                 train_metric = jax_utils.unreplicate(train_metric)
@@ -523,37 +527,37 @@ if __name__ == "__main__":
                     f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                 )
                 train_metrics = []
-        # ======================== Evaluating ==============================
-        num_eval_samples = len(tokenized_datasets["validation"])
-        eval_samples_idx = jnp.arange(num_eval_samples)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
-        eval_metrics = []
-        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
-            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
-            model_inputs = data_collator(samples, pad_to_multiple_of=16)
-            # Model forward
-            model_inputs = shard(model_inputs.data)
-            metrics = p_eval_step(state.params, model_inputs)
-            eval_metrics.append(metrics)
-        # normalize eval metrics
-        eval_metrics = get_metrics(eval_metrics)
-        eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
-        eval_normalizer = eval_metrics.pop("normalizer")
-        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
-        # Update progress bar
-        epochs.desc = (
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
-        )
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_eval_metric(summary_writer, eval_metrics, cur_step)
-        # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-                push_to_hub=training_args.push_to_hub,
-                commit_message=f"Saving weights and logs of epoch {epoch+1}",
-            )
+            if cur_step % training_args.eval_steps == 0 and step > 0:
+                # ======================== Evaluating ==============================
+                num_eval_samples = len(tokenized_datasets["validation"])
+                eval_samples_idx = jnp.arange(num_eval_samples)
+                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+                eval_metrics = []
+                for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+                    samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+                    model_inputs = data_collator(samples, pad_to_multiple_of=16)
+                    # Model forward
+                    model_inputs = shard(model_inputs.data)
+                    metrics = p_eval_step(state.params, model_inputs)
+                    eval_metrics.append(metrics)
+                # normalize eval metrics
+                eval_metrics = get_metrics(eval_metrics)
+                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+                eval_normalizer = eval_metrics.pop("normalizer")
+                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+                # Update progress bar
+                epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+                # Save metrics
+                if has_tensorboard and jax.process_index() == 0:
+                    cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
+                    write_eval_metric(summary_writer, eval_metrics, cur_step)
+            if cur_step % training_args.save_steps == 0 and step > 0:
+                # save checkpoint after each epoch and push checkpoint to the hub
+                if jax.process_index() == 0:
+                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                    model.save_pretrained(
+                        training_args.output_dir,
+                        params=params,
+                        push_to_hub=training_args.push_to_hub,
+                        commit_message=f"Saving weights and logs of step {cur_step}",
+                    )
tokenizer.json CHANGED
File without changes
train_tokenizer.py CHANGED
File without changes