Spaces:
Running
Running
feat(train): google-cloud-storage is optional
Browse files- tools/train/train.py +12 -1
tools/train/train.py
CHANGED
@@ -42,7 +42,6 @@ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
|
42 |
from flax.serialization import from_bytes, to_bytes
|
43 |
from flax.training import train_state
|
44 |
from flax.training.common_utils import onehot
|
45 |
-
from google.cloud import storage
|
46 |
from jax.experimental import PartitionSpec, maps
|
47 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
48 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
@@ -58,6 +57,11 @@ from dalle_mini.model import (
|
|
58 |
set_partitions,
|
59 |
)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
61 |
cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
|
62 |
|
63 |
logger = logging.getLogger(__name__)
|
@@ -144,6 +148,9 @@ class ModelArguments:
|
|
144 |
if self.restore_state.startswith("gs://"):
|
145 |
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
146 |
bucket, blob_name = str(bucket_path).split("/", 1)
|
|
|
|
|
|
|
147 |
client = storage.Client()
|
148 |
bucket = client.bucket(bucket)
|
149 |
blob = bucket.blob(blob_name)
|
@@ -456,6 +463,10 @@ class TrainingArguments:
|
|
456 |
assert (
|
457 |
jax.local_device_count() == 8
|
458 |
), "TPUs in use, please check running processes"
|
|
|
|
|
|
|
|
|
459 |
assert self.optim in [
|
460 |
"distributed_shampoo",
|
461 |
"adam",
|
|
|
42 |
from flax.serialization import from_bytes, to_bytes
|
43 |
from flax.training import train_state
|
44 |
from flax.training.common_utils import onehot
|
|
|
45 |
from jax.experimental import PartitionSpec, maps
|
46 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
47 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
|
|
57 |
set_partitions,
|
58 |
)
|
59 |
|
60 |
+
try:
|
61 |
+
from google.cloud import storage
|
62 |
+
except:
|
63 |
+
storage = None
|
64 |
+
|
65 |
cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
|
66 |
|
67 |
logger = logging.getLogger(__name__)
|
|
|
148 |
if self.restore_state.startswith("gs://"):
|
149 |
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
150 |
bucket, blob_name = str(bucket_path).split("/", 1)
|
151 |
+
assert (
|
152 |
+
storage is not None
|
153 |
+
), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
|
154 |
client = storage.Client()
|
155 |
bucket = client.bucket(bucket)
|
156 |
blob = bucket.blob(blob_name)
|
|
|
463 |
assert (
|
464 |
jax.local_device_count() == 8
|
465 |
), "TPUs in use, please check running processes"
|
466 |
+
if self.output_dir.startswith("gs://"):
|
467 |
+
assert (
|
468 |
+
storage is not None
|
469 |
+
), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
|
470 |
assert self.optim in [
|
471 |
"distributed_shampoo",
|
472 |
"adam",
|