boris committed
Commit ed93c8a
1 Parent(s): a6252c9

feat: split shards by host

Files changed (1):
  1. dalle_mini/data.py (+30 -11)
dalle_mini/data.py CHANGED

```diff
@@ -4,9 +4,9 @@ from functools import partial
 import jax
 import jax.numpy as jnp
 import numpy as np
+from braceexpand import braceexpand
 from datasets import Dataset, load_dataset
 from flax.training.common_utils import shard
-from braceexpand import braceexpand
 
 from .text import TextNormalizer
 
```
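The hunk above only moves the `braceexpand` import into sorted order with the other third-party imports. For context, `braceexpand` turns a shell-style brace pattern into the list of shard files that the logic below operates on. A minimal sketch, with hypothetical file names:

```python
from braceexpand import braceexpand

# Shell-style brace expansion: one pattern stands for many shard files.
# The paths here are hypothetical.
shards = list(braceexpand("data/train-{0000..0003}.jsonl"))
print(shards)
# ['data/train-0000.jsonl', 'data/train-0001.jsonl',
#  'data/train-0002.jsonl', 'data/train-0003.jsonl']
```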
```diff
@@ -30,8 +30,10 @@ class Dataset:
     train_dataset: Dataset = field(init=False)
     eval_dataset: Dataset = field(init=False)
     rng_dataset: jnp.ndarray = field(init=False)
+    multi_hosts: bool = field(init=False)
 
     def __post_init__(self):
+        self.multi_hosts = jax.process_count() > 1
         # define data_files
         if self.train_file is not None or self.validation_file is not None:
             # accept braceexpand notation
```
```diff
@@ -39,6 +41,11 @@ class Dataset:
                 f = getattr(self, k)
                 if isinstance(f, str):
                     setattr(self, k, list(braceexpand(f)))
+            # for a list of files, split training data shards by host
+            if isinstance(self.train_file, list) and self.multi_hosts:
+                self.train_file = self.train_file[
+                    jax.process_index() :: jax.process_count()
+                ]
             data_files = {
                 "train": self.train_file,
                 "validation": self.validation_file,
```
```diff
@@ -169,17 +176,29 @@ class Dataset:
                 batch = shard(batch)
                 yield batch
 
-        def _dataloader_datasets_streaming(dataset: Dataset, batch_size: int):
+        def _dataloader_datasets_streaming(
+            dataset: Dataset, batch_size: int, epoch: int
+        ):
+            # epoch is only used for multi-host
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
-            for item in dataset:
-                for k, v in item.items():
-                    batch[k].append(v)
-                if len(batch[keys[0]]) == batch_size:
-                    batch = {k: jnp.array(v) for k, v in batch.items()}
-                    batch = shard(batch)
-                    yield batch
-                    batch = {k: [] for k in keys}
+            first_loop = True
+            while self.multi_hosts or first_loop:
+                # in multi-host, we run forever (no epoch) as hosts need to stop
+                # at the same time and we don't know how much data is on each host
+                if not first_loop:
+                    # multi-host setting: reshuffle shards
+                    epoch += 1
+                    dataset.set_epoch(epoch)
+                for item in dataset:
+                    for k, v in item.items():
+                        batch[k].append(v)
+                    if len(batch[keys[0]]) == batch_size:
+                        batch = {k: jnp.array(v) for k, v in batch.items()}
+                        batch = shard(batch)
+                        yield batch
+                        batch = {k: [] for k in keys}
+                first_loop = False
 
         if split == "train":
             ds = self.train_dataset
```
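In a multi-host run the streaming loader becomes an endless generator: every host must keep yielding batches, since each training step is a collective operation and a host whose (possibly smaller) shard split ran dry would stall the others. Each time a pass over the local shards finishes, the epoch counter is bumped and `dataset.set_epoch(...)` reshuffles the shards for the next pass. A minimal sketch of that pattern, with hypothetical names standing in for the dataset iteration above:

```python
from itertools import islice

def endless_batches(make_epoch_iter, epoch=0):
    # Stand-in for the while-loop above: run one pass, bump the epoch so
    # the next pass is reshuffled, and never stop on our own.
    while True:
        yield from make_epoch_iter(epoch)
        epoch += 1

# A two-item "epoch": the stream keeps going across epoch boundaries.
stream = endless_batches(lambda e: iter([f"batch-{e}-0", f"batch-{e}-1"]))
print(list(islice(stream, 5)))
# ['batch-0-0', 'batch-0-1', 'batch-1-0', 'batch-1-1', 'batch-2-0']
```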
```diff
@@ -191,7 +210,7 @@ class Dataset:
         if self.streaming:
             if split == "train":
                 ds.set_epoch(epoch)
-            return _dataloader_datasets_streaming(ds, batch_size)
+            return _dataloader_datasets_streaming(ds, batch_size, epoch)
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
```
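Both loaders hand each batch to `shard` from `flax.training.common_utils`, which splits the leading batch axis across the local devices so the result can be fed to a `pmap`-ed train step. A minimal sketch of the effect, assuming the batch size is divisible by the local device count:

```python
import jax
import numpy as np
from flax.training.common_utils import shard

n_dev = jax.local_device_count()
batch = {"input_ids": np.zeros((8, 16), dtype=np.int32)}  # hypothetical batch

# shard() reshapes (8, 16) -> (n_dev, 8 // n_dev, 16); this assumes
# 8 % n_dev == 0, as it is on 1, 2, 4, or 8 local devices.
sharded = shard(batch)
assert sharded["input_ids"].shape == (n_dev, 8 // n_dev, 16)
```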