boris commited on
Commit
da9367c
1 Parent(s): 68cc185

feat(train): use compilation cache

Browse files
Files changed (1) hide show
  1. tools/train/train.py +6 -0
tools/train/train.py CHANGED
@@ -41,6 +41,7 @@ from flax.serialization import from_bytes, to_bytes
41
  from flax.training import train_state
42
  from flax.training.common_utils import onehot
43
  from jax.experimental import PartitionSpec, maps
 
44
  from jax.experimental.pjit import pjit, with_sharding_constraint
45
  from tqdm import tqdm
46
  from transformers import HfArgumentParser
@@ -53,6 +54,11 @@ from dalle_mini.model import (
53
  set_partitions,
54
  )
55
 
 
 
 
 
 
56
  logger = logging.getLogger(__name__)
57
 
58
 
 
41
  from flax.training import train_state
42
  from flax.training.common_utils import onehot
43
  from jax.experimental import PartitionSpec, maps
44
+ from jax.experimental.compilation_cache import compilation_cache as cc
45
  from jax.experimental.pjit import pjit, with_sharding_constraint
46
  from tqdm import tqdm
47
  from transformers import HfArgumentParser
 
54
  set_partitions,
55
  )
56
 
57
+ cc.initialize_cache(
58
+ "/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
59
+ )
60
+
61
+
62
  logger = logging.getLogger(__name__)
63
 
64