Commit
•
b1b3841
1
Parent(s):
37fa322
Saving weights and logs of step 8
Browse files- README.md +0 -0
- config.json +0 -0
- create_config.py +0 -0
- flax_model.msgpack +3 -0
- run.sh +5 -2
- run.sh.save +16 -0
- run_mlm_flax.py +39 -35
- tokenizer.json +0 -0
- train_tokenizer.py +0 -0
README.md
CHANGED
File without changes
|
config.json
CHANGED
File without changes
|
create_config.py
CHANGED
File without changes
|
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a2e26684c7b415b88900d2b10f657004a4262d41aca55f28d52013b051535c43
|
3 |
+
size 498796983
|
run.sh
CHANGED
@@ -7,10 +7,13 @@
|
|
7 |
--dataset_name="oscar" \
|
8 |
--dataset_config_name="unshuffled_deduplicated_it" \
|
9 |
--max_seq_length="128" \
|
10 |
-
--per_device_train_batch_size="
|
11 |
-
--per_device_eval_batch_size="
|
12 |
--learning_rate="3e-4" \
|
13 |
--warmup_steps="1000" \
|
14 |
--overwrite_output_dir \
|
15 |
--num_train_epochs="8" \
|
|
|
|
|
|
|
16 |
--push_to_hub
|
|
|
7 |
--dataset_name="oscar" \
|
8 |
--dataset_config_name="unshuffled_deduplicated_it" \
|
9 |
--max_seq_length="128" \
|
10 |
+
--per_device_train_batch_size="1" \
|
11 |
+
--per_device_eval_batch_size="1" \
|
12 |
--learning_rate="3e-4" \
|
13 |
--warmup_steps="1000" \
|
14 |
--overwrite_output_dir \
|
15 |
--num_train_epochs="8" \
|
16 |
+
--logging_steps="10" \
|
17 |
+
--save_steps="8" \
|
18 |
+
--eval_steps="15" \
|
19 |
--push_to_hub
|
run.sh.save
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/usr/bin/env bash
|
2 |
+
./run_mlm_flax.py \
|
3 |
+
--output_dir="./" \
|
4 |
+
--model_type="roberta" \
|
5 |
+
--config_name="./" \
|
6 |
+
--tokenizer_name="./" \
|
7 |
+
--dataset_name="oscar" \
|
8 |
+
--dataset_config_name="unshuffled_deduplicated_it" \
|
9 |
+
--max_seq_length="128" \
|
10 |
+
--per_device_train_batch_size="4" \
|
11 |
+
--per_device_eval_batch_size="4" \
|
12 |
+
--learning_rate="3e-4" \
|
13 |
+
--warmup_steps="1000" \
|
14 |
+
--overwrite_output_dir \
|
15 |
+
--num_train_epochs="8" \
|
16 |
+
--push_to_hub
|
run_mlm_flax.py
CHANGED
@@ -297,6 +297,10 @@ if __name__ == "__main__":
|
|
297 |
if extension == "txt":
|
298 |
extension = "text"
|
299 |
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
|
|
|
|
|
|
|
|
300 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
301 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
302 |
# Load pretrained model and tokenizer
|
@@ -512,7 +516,7 @@ if __name__ == "__main__":
|
|
512 |
model_inputs = shard(model_inputs.data)
|
513 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
514 |
train_metrics.append(train_metric)
|
515 |
-
cur_step = epoch * num_train_samples + step
|
516 |
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
517 |
# Save metrics
|
518 |
train_metric = jax_utils.unreplicate(train_metric)
|
@@ -523,37 +527,37 @@ if __name__ == "__main__":
|
|
523 |
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
524 |
)
|
525 |
train_metrics = []
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
|
|
297 |
if extension == "txt":
|
298 |
extension = "text"
|
299 |
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
300 |
+
|
301 |
+
datasets["train"] = datasets["train"].select(range(10000))
|
302 |
+
datasets["validation"] = datasets["validation"].select(range(1000))
|
303 |
+
|
304 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
305 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
306 |
# Load pretrained model and tokenizer
|
|
|
516 |
model_inputs = shard(model_inputs.data)
|
517 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
518 |
train_metrics.append(train_metric)
|
519 |
+
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
520 |
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
521 |
# Save metrics
|
522 |
train_metric = jax_utils.unreplicate(train_metric)
|
|
|
527 |
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
528 |
)
|
529 |
train_metrics = []
|
530 |
+
if cur_step % training_args.eval_steps == 0 and step > 0:
|
531 |
+
# ======================== Evaluating ==============================
|
532 |
+
num_eval_samples = len(tokenized_datasets["validation"])
|
533 |
+
eval_samples_idx = jnp.arange(num_eval_samples)
|
534 |
+
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
535 |
+
eval_metrics = []
|
536 |
+
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
537 |
+
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
538 |
+
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
539 |
+
# Model forward
|
540 |
+
model_inputs = shard(model_inputs.data)
|
541 |
+
metrics = p_eval_step(state.params, model_inputs)
|
542 |
+
eval_metrics.append(metrics)
|
543 |
+
# normalize eval metrics
|
544 |
+
eval_metrics = get_metrics(eval_metrics)
|
545 |
+
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
546 |
+
eval_normalizer = eval_metrics.pop("normalizer")
|
547 |
+
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
548 |
+
# Update progress bar
|
549 |
+
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
550 |
+
# Save metrics
|
551 |
+
if has_tensorboard and jax.process_index() == 0:
|
552 |
+
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
553 |
+
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
554 |
+
if cur_step % training_args.save_steps == 0 and step > 0:
|
555 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
556 |
+
if jax.process_index() == 0:
|
557 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
558 |
+
model.save_pretrained(
|
559 |
+
training_args.output_dir,
|
560 |
+
params=params,
|
561 |
+
push_to_hub=training_args.push_to_hub,
|
562 |
+
commit_message=f"Saving weights and logs of step {cur_step}",
|
563 |
+
)
|
tokenizer.json
CHANGED
File without changes
|
train_tokenizer.py
CHANGED
File without changes
|