boris commited on
Commit
9f5e879
1 Parent(s): d368fb6

feat(train): local jax cache

Browse files
Files changed (1) hide show
  1. tools/train/train.py +1 -3
tools/train/train.py CHANGED
@@ -57,9 +57,7 @@ from dalle_mini.model import (
57
  set_partitions,
58
  )
59
 
60
- cc.initialize_cache(
61
- "/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
62
- )
63
 
64
  logger = logging.getLogger(__name__)
65
 
 
57
  set_partitions,
58
  )
59
 
60
+ cc.initialize_cache("./jax_cache", max_cache_size_bytes=5 * 2**30)
 
 
61
 
62
  logger = logging.getLogger(__name__)
63