Spaces:
Running
Running
feat: shard by host is optional
Browse files
- dalle_mini/data.py +7 -2
- tools/train/train.py +8 -2
dalle_mini/data.py
CHANGED
@@ -27,6 +27,7 @@ class Dataset:
|
|
27 |
do_train: bool = False
|
28 |
do_eval: bool = True
|
29 |
seed_dataset: int = None
|
|
|
30 |
train_dataset: Dataset = field(init=False)
|
31 |
eval_dataset: Dataset = field(init=False)
|
32 |
rng_dataset: jnp.ndarray = field(init=False)
|
@@ -42,7 +43,11 @@ class Dataset:
|
|
42 |
if isinstance(f, str):
|
43 |
setattr(self, k, list(braceexpand(f)))
|
44 |
# for list of files, split training data shards by host
|
45 |
-
if isinstance(self.train_file, list) and self.multi_hosts:
|
|
|
|
|
|
|
|
|
46 |
self.train_file = self.train_file[
|
47 |
jax.process_index() :: jax.process_count()
|
48 |
]
|
@@ -185,7 +190,7 @@ class Dataset:
|
|
185 |
first_loop = True
|
186 |
while self.multi_hosts or first_loop:
|
187 |
# in multi-host, we run forever (no epoch) as hosts need to stop
|
188 |
-
# at same time and we don't know how much data is on each host
|
189 |
if not first_loop:
|
190 |
# multi-host setting, we reshuffle shards
|
191 |
epoch += 1
|
|
|
27 |
do_train: bool = False
|
28 |
do_eval: bool = True
|
29 |
seed_dataset: int = None
|
30 |
+
shard_by_host: bool = False
|
31 |
train_dataset: Dataset = field(init=False)
|
32 |
eval_dataset: Dataset = field(init=False)
|
33 |
rng_dataset: jnp.ndarray = field(init=False)
|
|
|
43 |
if isinstance(f, str):
|
44 |
setattr(self, k, list(braceexpand(f)))
|
45 |
# for list of files, split training data shards by host
|
46 |
+
if (
|
47 |
+
isinstance(self.train_file, list)
|
48 |
+
and self.multi_hosts
|
49 |
+
and self.shard_by_host
|
50 |
+
):
|
51 |
self.train_file = self.train_file[
|
52 |
jax.process_index() :: jax.process_count()
|
53 |
]
|
|
|
190 |
first_loop = True
|
191 |
while self.multi_hosts or first_loop:
|
192 |
# in multi-host, we run forever (no epoch) as hosts need to stop
|
193 |
+
# at the same time and we don't know how much data is on each host
|
194 |
if not first_loop:
|
195 |
# multi-host setting, we reshuffle shards
|
196 |
epoch += 1
|
tools/train/train.py
CHANGED
@@ -112,16 +112,22 @@ class DataTrainingArguments:
|
|
112 |
metadata={"help": "An optional input evaluation data file (glob acceptable)."},
|
113 |
)
|
114 |
# data loading should not be a bottleneck so we use "streaming" mode by default
|
115 |
-
streaming: bool = field(
|
116 |
default=True,
|
117 |
metadata={"help": "Whether to stream the dataset."},
|
118 |
)
|
119 |
-
use_auth_token: bool = field(
|
120 |
default=False,
|
121 |
metadata={
|
122 |
"help": "Whether to use the authentication token for private datasets."
|
123 |
},
|
124 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
max_train_samples: Optional[int] = field(
|
126 |
default=None,
|
127 |
metadata={
|
|
|
112 |
metadata={"help": "An optional input evaluation data file (glob acceptable)."},
|
113 |
)
|
114 |
# data loading should not be a bottleneck so we use "streaming" mode by default
|
115 |
+
streaming: Optional[bool] = field(
|
116 |
default=True,
|
117 |
metadata={"help": "Whether to stream the dataset."},
|
118 |
)
|
119 |
+
use_auth_token: Optional[bool] = field(
|
120 |
default=False,
|
121 |
metadata={
|
122 |
"help": "Whether to use the authentication token for private datasets."
|
123 |
},
|
124 |
)
|
125 |
+
shard_by_host: Optional[bool] = field(
|
126 |
+
default=False,
|
127 |
+
metadata={
|
128 |
+
"help": "Whether to shard data files by host in multi-host environments."
|
129 |
+
},
|
130 |
+
)
|
131 |
max_train_samples: Optional[int] = field(
|
132 |
default=None,
|
133 |
metadata={
|