yhavinga committed
Commit e867aeb
Parent: b037791

Saving weights and logs of step 300

flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:7df1f5835f058622107709c8dc20a5d3452a8facd5dd852b33913d99ebc91e5a
+ oid sha256:457c5948252576d9d5252b28b79a754223d3dea5a24a77f5b2b7cb5189129499
  size 891548548
run_t5.sh CHANGED
@@ -16,9 +16,8 @@ mkdir -p "${MODEL_DIR}/runs"
  --preprocessing_num_workers="96" \
  --do_train --do_eval \
  --adafactor \
- --dtype="bfloat16" \
  --max_seq_length="512" \
- --gradient_accumulation_steps="4" \
+ --gradient_accumulation_steps="16" \
  --per_device_train_batch_size="32" \
  --per_device_eval_batch_size="32" \
  --learning_rate="5e-3" \
@@ -32,3 +31,7 @@ mkdir -p "${MODEL_DIR}/runs"
  #git add pytorch_model.bin
  #git commit -m "Update pytorch model after training"
  #git push origin main
+
+
+ # --dtype="bfloat16" \
+ # --resume_from_checkpoint="${MODEL_DIR}/ckpt-3300" \
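
The switch above from gradient_accumulation_steps=4 to 16 (with --dtype="bfloat16" now commented out) raises the effective batch size per optimizer update. A minimal sketch of that arithmetic, assuming the run sees every accelerator through jax.device_count() (this snippet and the device count are illustrative, not part of the repo):

import jax

# Flag values from run_t5.sh above.
per_device_train_batch_size = 32
gradient_accumulation_steps = 16  # raised from 4 in this commit

# Examples consumed per optimizer update:
# per-device batch * number of devices * accumulation steps.
effective_batch_size = (
    per_device_train_batch_size * jax.device_count() * gradient_accumulation_steps
)
print(f"devices={jax.device_count()}, effective batch size={effective_batch_size}")
# e.g. on an 8-core TPU this would be 32 * 8 * 16 = 4096 examples per update.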
run_t5_mlm_flax_custom_dataset.py CHANGED
@@ -722,6 +722,9 @@ if __name__ == "__main__":
 
      num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
 
+     steps_per_epoch = len(tokenized_datasets['train']) // train_batch_size
+     total_train_steps = steps_per_epoch * num_epochs
+
      # Create learning rate schedule
 
      # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at 6% of training steps
@@ -775,6 +778,11 @@ if __name__ == "__main__":
      # Setup train state
      state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
 
+     if training_args.resume_from_checkpoint:
+         state, resume_step = restore_checkpoint(training_args.resume_from_checkpoint, state)
+     else:
+         resume_step = 0
+
      # Define gradient update step fn
      def train_step(state, batch, dropout_rng):
          dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
@@ -828,8 +836,7 @@ if __name__ == "__main__":
      # Replicate the train state on each device
      state = jax_utils.replicate(state)
 
-     steps_per_epoch = len(datasets['train']) // train_batch_size
-     total_train_steps = steps_per_epoch * num_epochs
+
 
      logger.info("***** Running training *****")
      logger.info(f" Num examples = {len(datasets['train'])}")
@@ -855,6 +862,11 @@ if __name__ == "__main__":
 
          # Gather the indexes for creating the batch and do a training step
          for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+             cur_step = epoch * (num_train_samples // train_batch_size) + step
+             # skip to the step from which we are resuming
+             if cur_step < resume_step:
+                 continue
+
              samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
              model_inputs = data_collator(samples)
 
@@ -863,7 +875,6 @@ if __name__ == "__main__":
              state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
              train_metrics.append(train_metric)
 
-             cur_step = epoch * (num_train_samples // train_batch_size) + step
 
              if cur_step % training_args.logging_steps * grad_accum_steps == 0 and cur_step > 0:
                  # Save metrics
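
The resume logic added above calls restore_checkpoint, whose definition is not part of the hunks shown. A minimal sketch of what such a helper could look like with flax.training.checkpoints, assuming the checkpoint stores the whole TrainState (the repo's actual implementation may differ):

from flax.training import checkpoints

def restore_checkpoint(ckpt_path, state):
    # Restore the saved TrainState; Flax returns `state` unchanged if no checkpoint is found.
    restored = checkpoints.restore_checkpoint(ckpt_path, target=state)
    # The optimizer step stored in the TrainState is the step to resume from.
    return restored, int(restored.step)

# Illustrative saving counterpart (the "ckpt-3300" path in run_t5.sh suggests a "ckpt-" prefix):
# checkpoints.save_checkpoint(model_dir, jax_utils.unreplicate(state), step=cur_step, prefix="ckpt-")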
runs/Jul10_07-37-20_t1v-n-0e7426e8-w-0/events.out.tfevents.1625902752.t1v-n-0e7426e8-w-0.18397.3.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1aa4fd14ba6d0007ac2b4c7ad5f7b03ab486b3899ece3eba1fefe852923f2366
+ size 40
runs/Jul10_07-45-49_t1v-n-0e7426e8-w-0/events.out.tfevents.1625903173.t1v-n-0e7426e8-w-0.20563.3.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9086b97ea9ba59e96e4c66b26c205fe1207d0a94ab355127a1e4f8078d84a269
+ size 45399