boris commited on
Commit
02b2308
1 Parent(s): 955dc20

feat(train): google-cloud-storage is optional

Browse files
Files changed (1) hide show
  1. tools/train/train.py +12 -1
tools/train/train.py CHANGED
@@ -42,7 +42,6 @@ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
42
  from flax.serialization import from_bytes, to_bytes
43
  from flax.training import train_state
44
  from flax.training.common_utils import onehot
45
- from google.cloud import storage
46
  from jax.experimental import PartitionSpec, maps
47
  from jax.experimental.compilation_cache import compilation_cache as cc
48
  from jax.experimental.pjit import pjit, with_sharding_constraint
@@ -58,6 +57,11 @@ from dalle_mini.model import (
58
  set_partitions,
59
  )
60
 
 
 
 
 
 
61
  cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
62
 
63
  logger = logging.getLogger(__name__)
@@ -144,6 +148,9 @@ class ModelArguments:
144
  if self.restore_state.startswith("gs://"):
145
  bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
146
  bucket, blob_name = str(bucket_path).split("/", 1)
 
 
 
147
  client = storage.Client()
148
  bucket = client.bucket(bucket)
149
  blob = bucket.blob(blob_name)
@@ -456,6 +463,10 @@ class TrainingArguments:
456
  assert (
457
  jax.local_device_count() == 8
458
  ), "TPUs in use, please check running processes"
 
 
 
 
459
  assert self.optim in [
460
  "distributed_shampoo",
461
  "adam",
 
42
  from flax.serialization import from_bytes, to_bytes
43
  from flax.training import train_state
44
  from flax.training.common_utils import onehot
 
45
  from jax.experimental import PartitionSpec, maps
46
  from jax.experimental.compilation_cache import compilation_cache as cc
47
  from jax.experimental.pjit import pjit, with_sharding_constraint
 
57
  set_partitions,
58
  )
59
 
60
+ try:
61
+ from google.cloud import storage
62
+ except:
63
+ storage = None
64
+
65
  cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
66
 
67
  logger = logging.getLogger(__name__)
 
148
  if self.restore_state.startswith("gs://"):
149
  bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
150
  bucket, blob_name = str(bucket_path).split("/", 1)
151
+ assert (
152
+ storage is not None
153
+ ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
154
  client = storage.Client()
155
  bucket = client.bucket(bucket)
156
  blob = bucket.blob(blob_name)
 
463
  assert (
464
  jax.local_device_count() == 8
465
  ), "TPUs in use, please check running processes"
466
+ if self.output_dir.startswith("gs://"):
467
+ assert (
468
+ storage is not None
469
+ ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
470
  assert self.optim in [
471
  "distributed_shampoo",
472
  "adam",