yhavinga committed
Commit: d9993eb
Parent: a1d2f2c

Update scripts to work around collator ValueError. Update weights

config.json CHANGED
@@ -1,4 +1,5 @@
 {
+  "_name_or_path": ".",
   "architectures": [
     "T5ForConditionalGeneration"
   ],
@@ -50,6 +51,7 @@
       "prefix": "translate English to Romanian: "
     }
   },
+  "torch_dtype": "float32",
   "transformers_version": "4.9.0.dev0",
   "use_cache": true,
   "vocab_size": 32103
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:02c8aedd34c528d3a7806d216941cc23732a751a8d687f8bf1db06eb1e1e75a3
+oid sha256:8c8d5a4eb1275b4c679b148f38edb974772997a3925809f39095204009f83502
 size 891548548
opt_state.msgpack ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97c0ff372805930fa4d7e81ae09094b7daf3cc2c1ba06224fc522a8e672af91a
+size 1985609
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1a8f60fdc3ad43a82bab7ec3dcaf1138179d7508798267becb15426d86b9385f
+oid sha256:782edc5c7aa8aa66320a3417abff572760287ee6a7759f1867486d2217563650
 size 891650495
run_t5.sh CHANGED
@@ -7,28 +7,42 @@ mkdir -p "${MODEL_DIR}/runs"
 # T5 paper lr 0.01 with batch size 128
 # We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2
 
-./run_t5_mlm_flax_custom_dataset.py \
-    --output_dir="${MODEL_DIR}" \
-    --model_type="t5" \
-    --config_name="flax-community/${MODEL}" \
-    --tokenizer_name="${MODEL_DIR}" \
-    --preprocessing_num_workers="96" \
-    --do_train --do_eval \
-    --adafactor \
-    --max_seq_length="512" \
-    --per_device_train_batch_size="32" \
-    --per_device_eval_batch_size="32" \
-    --learning_rate="5e-3" \
-    --dtype="bfloat16" \
-    --overwrite_output_dir \
-    --num_train_epochs="3" \
-    --logging_steps="50" \
-    --save_steps="2000" \
-    --eval_steps="10000000" \
-    --resume_from_checkpoint="${MODEL_DIR}/ckpt-18000" \
-    --warmup_steps="3413" \
-    --push_to_hub
-
+while true; do
+
+  # Set the seed to a random value before each run, so data shuffling per epoch differs between runs.
+  # This kills reproducibility, but is required as long as a ValueError can be raised during training.
+  SEED=$RANDOM
+
+  ./run_t5_mlm_flax_custom_dataset.py \
+      --output_dir="${MODEL_DIR}" \
+      --model_type="t5" \
+      --config_name="flax-community/${MODEL}" \
+      --tokenizer_name="${MODEL_DIR}" \
+      --seed="${SEED}" \
+      --preprocessing_num_workers="96" \
+      --do_train --do_eval \
+      --adafactor \
+      --max_seq_length="512" \
+      --per_device_train_batch_size="32" \
+      --per_device_eval_batch_size="32" \
+      --learning_rate="5e-3" \
+      --dtype="bfloat16" \
+      --overwrite_output_dir \
+      --num_train_epochs="3" \
+      --logging_steps="50" \
+      --save_steps="501" \
+      --eval_steps="10000000" \
+      --resume_from_checkpoint="${MODEL_DIR}" \
+      --warmup_steps="3413"
+
+  # \
+  # --push_to_hub
+
+  echo "RESTARTING"
+  sleep 20
+done
+#
+# \
 
 
 #git add pytorch_model.bin
@@ -37,3 +51,4 @@ mkdir -p "${MODEL_DIR}/runs"
 
 # --gradient_accumulation_steps="2" \
 
+# --resume_from_checkpoint="${MODEL_DIR}/ckpt-18000" \
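
The loop above simply restarts training with a fresh seed whenever the script dies on the collator ValueError, resuming from the latest checkpoint in the output directory. For reference, the same pattern could be driven from Python; this is a hypothetical sketch (wrapper script and argument handling are assumptions, flags abridged from the command above), not part of the repository:

import random
import subprocess
import sys
import time

MODEL_DIR = sys.argv[1]  # assumption: model dir passed as the first argument

while True:
    # Fresh seed each attempt, mirroring SEED=$RANDOM in run_t5.sh (0..32767).
    seed = random.randrange(2**15)
    result = subprocess.run([
        "./run_t5_mlm_flax_custom_dataset.py",
        f"--output_dir={MODEL_DIR}",
        f"--seed={seed}",
        f"--resume_from_checkpoint={MODEL_DIR}",
        # ... remaining flags as in run_t5.sh ...
    ])
    if result.returncode == 0:
        break  # clean exit: training completed without the ValueError
    print("RESTARTING")
    time.sleep(20)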
run_t5_mlm_flax_custom_dataset.py CHANGED
@@ -432,6 +432,11 @@ def save_checkpoint(model, save_dir, state, with_opt: bool = True):
         push_to_hub=training_args.push_to_hub,
         commit_message=f"Saving weights and logs of step {cur_step}",
     )
+    if with_opt:
+        with open(os.path.join(training_args.output_dir, "opt_state.msgpack"), "wb") as f:
+            f.write(to_bytes(state.opt_state))
+        with open(os.path.join(training_args.output_dir, "training_state.json"), "w") as f:
+            json.dump({"step": state.step.item()}, f)
     logger.info("checkpoint saved")
 
 
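These two files are what the restart loop in run_t5.sh resumes from: opt_state.msgpack holds the optimizer state serialized with flax.serialization.to_bytes, and training_state.json holds the step counter. A minimal sketch of the matching load path (restore_checkpoint is a hypothetical helper, not the script's actual function; it assumes a Flax train state with the same pytree structure as at save time):

import json
import os

from flax.serialization import from_bytes

def restore_checkpoint(state, output_dir):
    # Counterpart to the save code above: from_bytes deserializes the
    # msgpack bytes into the same pytree structure as state.opt_state.
    with open(os.path.join(output_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())
    with open(os.path.join(output_dir, "training_state.json")) as f:
        step = json.load(f)["step"]
    return state.replace(opt_state=opt_state, step=step)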
runs/{Jul11_17-06-36_t1v-n-0e7426e8-w-0/events.out.tfevents.1626023202.t1v-n-0e7426e8-w-0.178001.3.v2 → Jul12_06-43-08_t1v-n-0e7426e8-w-0/events.out.tfevents.1626072193.t1v-n-0e7426e8-w-0.238699.3.v2} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0b89824cdb72fe97627209c68074b163e725d00349a36ed38b233e7d579e1b92
-size 296685
+oid sha256:9f5f6fcc83f8cf7fac87cc276fa00a02c9ce4e252c6bb69a3988452bed73f67e
+size 200238
training_state.json ADDED
@@ -0,0 +1 @@
+{"step": 15004}