boris committed on
Commit
88c8e06
1 Parent(s): 274ba73

feat(data): support accumulation in non-streaming

Browse files
Files changed (1) hide show
  1. src/dalle_mini/data.py +10 -2
src/dalle_mini/data.py CHANGED
@@ -161,13 +161,16 @@ class Dataset:
161
  def _dataloader_datasets_non_streaming(
162
  dataset: Dataset,
163
  per_device_batch_size: int,
 
164
  rng: jax.random.PRNGKey = None,
165
  ):
166
  """
167
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
168
  Shuffle batches if rng is set.
169
  """
170
- batch_size = per_device_batch_size * num_devices
 
 
171
  steps_per_epoch = len(dataset) // batch_size
172
 
173
  if rng is not None:
@@ -183,6 +186,11 @@ class Dataset:
183
  for idx in batch_idx:
184
  batch = dataset[idx]
185
  batch = {k: jnp.array(v) for k, v in batch.items()}
 
 
 
 
 
186
  batch = shard(batch)
187
  yield batch
188
 
@@ -244,7 +252,7 @@ class Dataset:
244
  if split == "train":
245
  self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
246
  return _dataloader_datasets_non_streaming(
247
- ds, per_device_batch_size, input_rng
248
  )
249
 
250
  @property
 
161
  def _dataloader_datasets_non_streaming(
162
  dataset: Dataset,
163
  per_device_batch_size: int,
164
+ gradient_accumulation_steps: int,
165
  rng: jax.random.PRNGKey = None,
166
  ):
167
  """
168
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
169
  Shuffle batches if rng is set.
170
  """
171
+ batch_size = (
172
+ per_device_batch_size * num_devices * gradient_accumulation_steps
173
+ )
174
  steps_per_epoch = len(dataset) // batch_size
175
 
176
  if rng is not None:
 
186
  for idx in batch_idx:
187
  batch = dataset[idx]
188
  batch = {k: jnp.array(v) for k, v in batch.items()}
189
+ if gradient_accumulation_steps is not None:
190
+ batch = jax.tree_map(
191
+ lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
192
+ batch,
193
+ )
194
  batch = shard(batch)
195
  yield batch
196
 
 
252
  if split == "train":
253
  self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
254
  return _dataloader_datasets_non_streaming(
255
+ ds, per_device_batch_size, gradient_accumulation_steps, input_rng
256
  )
257
 
258
  @property