boris committed
Commit b7c7458
1 Parent(s): 8149924

fix(train): consider correct batch size

Files changed (2)
  1. src/dalle_mini/data.py +14 -27
  2. tools/train/train.py +11 -5
src/dalle_mini/data.py CHANGED
@@ -156,21 +156,19 @@ class Dataset:
         self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
     ):
         num_devices = jax.local_device_count()
+        total_batch_size = per_device_batch_size * num_devices
+        if gradient_accumulation_steps is not None:
+            total_batch_size *= gradient_accumulation_steps
 
         def _dataloader_datasets_non_streaming(
             dataset: Dataset,
-            per_device_batch_size: int,
-            gradient_accumulation_steps: int,
             rng: jax.random.PRNGKey = None,
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
             Shuffle batches if rng is set.
             """
-            batch_size = (
-                per_device_batch_size * num_devices * gradient_accumulation_steps
-            )
-            steps_per_epoch = len(dataset) // batch_size
+            steps_per_epoch = len(dataset) // total_batch_size
 
             if rng is not None:
                 batch_idx = jax.random.permutation(rng, len(dataset))
@@ -178,25 +176,24 @@ class Dataset:
                 batch_idx = jnp.arange(len(dataset))
 
             batch_idx = batch_idx[
-                : steps_per_epoch * batch_size
+                : steps_per_epoch * total_batch_size
             ]  # Skip incomplete batch.
-            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+            batch_idx = batch_idx.reshape((steps_per_epoch, total_batch_size))
 
             for idx in batch_idx:
                 batch = dataset[idx]
                 batch = {k: jnp.array(v) for k, v in batch.items()}
                 if gradient_accumulation_steps is not None:
                     batch = jax.tree_map(
-                        lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
+                        lambda x: x.reshape(
+                            (gradient_accumulation_steps, -1) + x.shape[1:]
+                        ),
                         batch,
                     )
                 yield batch
 
         def _dataloader_datasets_streaming(
             dataset: Dataset,
-            split: str,
-            per_device_batch_size: int,
-            gradient_accumulation_steps: int,
             epoch: int,
         ):
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
@@ -214,19 +211,13 @@ class Dataset:
             for item in dataset:
                 for k, v in item.items():
                     batch[k].append(v)
-                # batch = 5, devices = 8, accumulation = 2 / batch_size = 5 x 8
-                # (40, 3, 3) -> shard 8 x (5, 3, 3)
-                # (16, 5, 3, 3) -> shard 8 x (2, 5, 3, 3)
-                if len(batch[keys[0]]) == per_device_batch_size * num_devices * (
-                    gradient_accumulation_steps
-                    if gradient_accumulation_steps is not None
-                    else 1
-                ):
+                if len(batch[keys[0]]) == total_batch_size:
                     batch = {k: jnp.array(v) for k, v in batch.items()}
                     if gradient_accumulation_steps is not None:
+                        # training mode
                         batch = jax.tree_map(
                             lambda x: x.reshape(
-                                (-1, per_device_batch_size) + x.shape[1:]
+                                (gradient_accumulation_steps, -1) + x.shape[1:]
                             ),
                             batch,
                         )
@@ -242,15 +233,11 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            return _dataloader_datasets_streaming(
-                ds, split, per_device_batch_size, gradient_accumulation_steps, epoch
-            )
+            return _dataloader_datasets_streaming(ds, epoch)
         else:
            if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
-            return _dataloader_datasets_non_streaming(
-                ds, per_device_batch_size, gradient_accumulation_steps, input_rng
-            )
+            return _dataloader_datasets_non_streaming(ds, input_rng)
 
     @property
     def length(self):
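For readers following the shape change: the dataloader now computes `total_batch_size` once per call and yields train batches whose leading axis is the number of gradient-accumulation steps, with the per-step minibatch on the second axis. A minimal sketch of the resulting shapes, assuming hypothetical sizes (5 examples per device, 8 local devices, 2 accumulation steps; none of these numbers come from the commit):

    import numpy as np

    per_device_batch_size = 5        # hypothetical
    num_devices = 8                  # hypothetical, jax.local_device_count() on a real host
    gradient_accumulation_steps = 2  # hypothetical
    seq_len = 3

    # total_batch_size as computed at the top of the new dataloader
    total_batch_size = per_device_batch_size * num_devices * gradient_accumulation_steps  # 80

    # a flat batch as accumulated before the reshape
    flat = np.zeros((total_batch_size, seq_len))

    # new layout: (gradient_accumulation_steps, minibatch_size, ...) = (2, 40, 3)
    reshaped = flat.reshape((gradient_accumulation_steps, -1) + flat.shape[1:])
    assert reshaped.shape == (gradient_accumulation_steps, num_devices * per_device_batch_size, seq_len)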
tools/train/train.py CHANGED
@@ -549,11 +549,11 @@ def main():
 
     # Store some constant
     num_epochs = training_args.num_train_epochs
-    # batch size per node
-    train_batch_size = (
+    # batch size
+    minibatch_size = (
         training_args.per_device_train_batch_size * jax.local_device_count()
     )
-    batch_size_per_node = train_batch_size * training_args.gradient_accumulation_steps
+    batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
     batch_size_per_step = batch_size_per_node * jax.process_count()
     eval_batch_size = (
         training_args.per_device_eval_batch_size * jax.local_device_count()
@@ -770,6 +770,12 @@ def main():
 
     # Define gradient update step fn
     def train_step(state, batch, delta_time):
+        # check correct batch shape during compilation
+        assert batch["labels"].shape[0:2] == (
+            training_args.gradient_accumulation_steps,
+            minibatch_size,
+        ), f"Expected label batch of shape gradient_acculumation x minibatch_size x items and got {batch['labels'].shape}"
+        # create a new rng
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
         # use a different rng per node
         dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
@@ -837,13 +843,13 @@ def main():
     # Create parallel version of the train and eval step
     p_train_step = pjit(
         train_step,
-        in_axis_resources=(state_spec, PartitionSpec("batch", None), None),
+        in_axis_resources=(state_spec, PartitionSpec(None, "batch"), None),
         out_axis_resources=(state_spec, None),
         donate_argnums=(0,),
     )
     p_eval_step = pjit(
         eval_step,
-        in_axis_resources=(param_spec, PartitionSpec("batch", None)),
+        in_axis_resources=(param_spec, PartitionSpec("batch")),
         out_axis_resources=None,
     )
 
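Taken together with the data.py change, each train batch now arrives as (gradient_accumulation_steps, minibatch_size, ...): that is what the new assert in train_step checks, and what PartitionSpec(None, "batch") shards along the second axis, while eval batches keep a single leading example axis sharded by PartitionSpec("batch"). A small sketch of that invariant, using hypothetical numbers not taken from the commit:

    import numpy as np

    gradient_accumulation_steps = 2  # hypothetical
    minibatch_size = 40              # per_device_train_batch_size * jax.local_device_count(), hypothetical
    seq_len = 3

    # a train batch in the new layout; axis 0 = accumulation steps, axis 1 = the sharded "batch" axis
    labels = np.zeros((gradient_accumulation_steps, minibatch_size, seq_len))

    # the invariant enforced by the assert added to train_step
    assert labels.shape[0:2] == (gradient_accumulation_steps, minibatch_size)

    # an eval batch keeps a flat leading example axis, matching PartitionSpec("batch")
    eval_labels = np.zeros((minibatch_size, seq_len))
    assert eval_labels.ndim == 2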