boris committed
Commit 0c9ff65
Parent: 1c9c679

fix(seq2seq): opt_state from ckpt + limit cache

Files changed (1)
  dev/seq2seq/run_seq2seq_flax.py: +6 -5
dev/seq2seq/run_seq2seq_flax.py CHANGED

@@ -20,10 +20,6 @@ Script adapted from run_summarization_flax.py
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
 import os
-# set a common huggingface cache folder (used with datasets and transformers) and wandb cache folder (used with artifacts)
-os.environ['HF_HOME'] = '/data/huggingface/'  # required before importing transformers & datasets
-os.environ['WANDB_CACHE_DIR'] = '/data/wandb/'  # required before importing wandb
-
 import logging as pylogging  # To avoid collision with transformers.utils.logging
 import sys
 import time
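
The removed block pinned the Hugging Face cache (shared by datasets and transformers) and the wandb artifacts cache to hard-coded /data paths, and had to run before those libraries were first imported. A minimal sketch of achieving the same from a launcher instead of inside the script (the stub below and its use of runpy are illustrative, not part of the commit):

# hypothetical launcher: set cache locations before the training script imports anything
import os
os.environ.setdefault('HF_HOME', '/data/huggingface/')    # datasets & transformers cache
os.environ.setdefault('WANDB_CACHE_DIR', '/data/wandb/')  # wandb artifacts cache

import runpy
runpy.run_path('dev/seq2seq/run_seq2seq_flax.py', run_name='__main__')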
@@ -442,6 +438,7 @@ def main():
     if (Path(artifact_dir) / 'opt_state.msgpack').exists():
         with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
             opt_state = from_bytes(state.opt_state, f.read())
+        state.replace(opt_state=opt_state)
 
     # restore steps
     if (Path(artifact_dir) / 'training_state.json').exists():
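
The added line feeds the deserialized optimizer state back into the train state. Note that flax train states are immutable struct dataclasses: replace() returns a new state rather than mutating in place, so the result only takes effect once re-assigned. A minimal sketch of the restore pattern, reusing the script's own state and artifact_dir names:

from pathlib import Path
from flax.serialization import from_bytes

ckpt = Path(artifact_dir) / 'opt_state.msgpack'
if ckpt.exists():
    with ckpt.open('rb') as f:
        # from_bytes uses the current opt_state as a pytree template for decoding
        opt_state = from_bytes(state.opt_state, f.read())
    state = state.replace(opt_state=opt_state)  # replace() is functional: re-assign the result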
@@ -836,6 +833,10 @@ def main():
         artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
         wandb.run.log_artifact(artifact)
 
+        # save some space
+        c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+        c.cleanup(wandb.util.from_human_size("15GB"))
+
     # save to the hub
     if training_args.push_to_hub:
         model.save_pretrained(
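
The new cleanup call trims wandb's local artifacts cache to a target size after each artifact upload, so repeated checkpoint artifacts don't fill the disk. A short usage sketch of the same calls; note that get_artifacts_cache() lives under wandb.wandb_sdk and is an internal API, so it may move between wandb releases:

import wandb

# from_human_size parses a human-readable size string into a byte count
cache = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
cache.cleanup(wandb.util.from_human_size("15GB"))  # keep the local cache under ~15 GB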
@@ -866,7 +867,7 @@ def main():
         # log metrics
         wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
-        if global_step % training_args.eval_steps == 0:
+        if training_args.eval_steps and global_step % training_args.eval_steps == 0:
             run_evaluation()
 
         if global_step % data_args.save_model_steps == 0:
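
The added truthiness guard short-circuits before the modulo, so a zero or None eval_steps now simply disables evaluation instead of raising. A tiny self-contained illustration with hypothetical values:

global_step = 100
eval_steps = 0  # e.g. evaluation disabled in the training config
# `global_step % eval_steps` alone would raise ZeroDivisionError;
# the guard evaluates falsy first, so the modulo never runs
should_eval = bool(eval_steps and global_step % eval_steps == 0)
print(should_eval)  # False -> run_evaluation() is skipped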
 