boris commited on
Commit
1bb3269
1 Parent(s): 5f28cd2

feat: handle model parallel

Browse files
src/dalle_mini/data.py CHANGED
@@ -85,7 +85,12 @@ class Dataset:
85
  else self.eval_dataset.select(range(self.max_eval_samples))
86
  )
87
 
88
- def preprocess(self, tokenizer, decoder_start_token_id, normalize_text, max_length):
 
 
 
 
 
89
  if self.streaming:
90
  # we need to shuffle early in streaming mode
91
  if hasattr(self, "train_dataset"):
 
85
  else self.eval_dataset.select(range(self.max_eval_samples))
86
  )
87
 
88
+ def preprocess(self, tokenizer, config):
89
+ # get required config variables
90
+ decoder_start_token_id = config.decoder_start_token_id
91
+ normalize_text = config.normalize_text
92
+ max_length = config.max_text_length
93
+
94
  if self.streaming:
95
  # we need to shuffle early in streaming mode
96
  if hasattr(self, "train_dataset"):
src/dalle_mini/model/configuration.py CHANGED
@@ -59,6 +59,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
59
  is_encoder_decoder=True,
60
  forced_eos_token_id=None,
61
  tie_word_embeddings=False, # different modalities and sizes
 
62
  **kwargs,
63
  ):
64
  self.normalize_text = normalize_text
@@ -87,28 +88,28 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
87
  scale_embedding # scale factor will be sqrt(d_model) if True
88
  )
89
 
90
- # remove inferred keys to prevent errors when loading config (passed as kwargs)
91
- for k in [
92
- "pad_token_id",
93
- "bos_token_id",
94
- "eos_token_id",
95
- "decoder_start_token_id",
96
- "min_length",
97
- "max_length",
98
- ]:
99
- kwargs.pop(k, None)
100
 
101
  super().__init__(
102
- pad_token_id=image_vocab_size
103
- + 1, # needed to avoid errors during generation (converted to jnp.array)
104
- bos_token_id=image_vocab_size + 1, # set to unreachable values
105
- eos_token_id=image_vocab_size + 1,
106
  is_encoder_decoder=is_encoder_decoder,
107
- decoder_start_token_id=image_vocab_size, # BOS appended to vocab
108
- forced_eos_token_id=forced_eos_token_id,
109
  tie_word_embeddings=tie_word_embeddings,
110
- min_length=image_length + 1,
111
- max_length=image_length + 1,
 
 
 
 
 
 
112
  **kwargs,
113
  )
114
 
 
59
  is_encoder_decoder=True,
60
  forced_eos_token_id=None,
61
  tie_word_embeddings=False, # different modalities and sizes
62
+ do_sample=True,
63
  **kwargs,
64
  ):
65
  self.normalize_text = normalize_text
 
88
  scale_embedding # scale factor will be sqrt(d_model) if True
89
  )
90
 
91
+ # special token id's are appended to vocab if not provided
92
+ decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
93
+ bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
94
+ pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
95
+ eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)
96
+
97
+ # we generate to image_length + 1 (for bos) by default
98
+ min_length = kwargs.pop("min_length", image_length + 1)
99
+ max_length = kwargs.pop("max_length", image_length + 1)
 
100
 
101
  super().__init__(
102
+ # args required in parent class
 
 
 
103
  is_encoder_decoder=is_encoder_decoder,
 
 
104
  tie_word_embeddings=tie_word_embeddings,
105
+ forced_eos_token_id=forced_eos_token_id,
106
+ decoder_start_token_id=decoder_start_token_id,
107
+ bos_token_id=bos_token_id,
108
+ pad_token_id=pad_token_id,
109
+ eos_token_id=eos_token_id,
110
+ min_length=min_length,
111
+ max_length=max_length,
112
+ do_sample=do_sample,
113
  **kwargs,
114
  )
115
 
src/dalle_mini/model/modeling.py CHANGED
@@ -54,7 +54,7 @@ logger = logging.get_logger(__name__)
54
  class FlaxBartAttention(FlaxBartAttention):
55
  """
56
  Edits:
57
- - causal mask is used only in decoder and considers image_length + 1 (for BOS)
58
  """
59
 
60
  def setup(self) -> None:
@@ -81,7 +81,7 @@ class FlaxBartAttention(FlaxBartAttention):
81
  if self.causal:
82
  # used only in decoder
83
  self.causal_mask = make_causal_mask(
84
- jnp.ones((1, self.config.image_length + 1), dtype="bool"), dtype="bool"
85
  )
86
 
87
 
@@ -240,7 +240,7 @@ class FlaxBartDecoder(FlaxBartDecoder):
240
  """
241
  Edits:
242
  - offset set to 0 (no padding token)
243
- - use image_length + 1 (for BOS) instead of max_position_embeddings
244
  - use custom FlaxBartDecoderLayerCollection
245
  - embed_tokens cannot be None (issue at compile time)
246
  """
@@ -258,7 +258,7 @@ class FlaxBartDecoder(FlaxBartDecoder):
258
  # and adjust num_embeddings appropriately. Other models don't have this hack
259
  self.offset = 0
260
  self.embed_positions = nn.Embed(
261
- self.config.image_length + 1 + self.offset, # image length + 1 for BOS
262
  embed_dim,
263
  embedding_init=jax.nn.initializers.normal(self.config.init_std),
264
  )
 
54
  class FlaxBartAttention(FlaxBartAttention):
55
  """
56
  Edits:
57
+ - causal mask is used only in decoder and considers image_length
58
  """
59
 
60
  def setup(self) -> None:
 
81
  if self.causal:
82
  # used only in decoder
83
  self.causal_mask = make_causal_mask(
84
+ jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
85
  )
86
 
87
 
 
240
  """
241
  Edits:
242
  - offset set to 0 (no padding token)
243
+ - use image_length instead of max_position_embeddings
244
  - use custom FlaxBartDecoderLayerCollection
245
  - embed_tokens cannot be None (issue at compile time)
246
  """
 
258
  # and adjust num_embeddings appropriately. Other models don't have this hack
259
  self.offset = 0
260
  self.embed_positions = nn.Embed(
261
+ self.config.image_length + self.offset, # image length for BOS
262
  embed_dim,
263
  embedding_init=jax.nn.initializers.normal(self.config.init_std),
264
  )
tools/train/train.py CHANGED
@@ -99,7 +99,7 @@ class ModelArguments:
99
 
100
  def __post_init__(self):
101
  if self.restore_state:
102
- assert (
103
  "/model-" in self.model_name_or_path
104
  ), "Restoring state only available with W&B artifact reference"
105
  self.state_artifact = self.model_name_or_path.replace(
@@ -222,12 +222,13 @@ class TrainingArguments:
222
  )
223
 
224
  per_device_train_batch_size: int = field(
225
- default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
 
226
  )
227
  per_device_eval_batch_size: Optional[int] = field(
228
  default=None,
229
  metadata={
230
- "help": "Batch size per GPU/TPU/CPU for evaluation. Same as training batch size if not set."
231
  },
232
  )
233
 
@@ -523,12 +524,7 @@ def main():
523
  # Preprocessing the datasets.
524
  # We need to normalize and tokenize inputs and targets.
525
 
526
- dataset.preprocess(
527
- tokenizer=tokenizer,
528
- decoder_start_token_id=model.config.decoder_start_token_id,
529
- normalize_text=model.config.normalize_text,
530
- max_length=model.config.max_text_length,
531
- )
532
 
533
  # Initialize our training
534
  rng = jax.random.PRNGKey(training_args.seed_model)
@@ -874,9 +870,17 @@ def main():
874
 
875
  # Define eval fn
876
  def eval_step(state, batch):
877
- batch, labels = batch.pop("labels")
878
- logits = model(**batch, params=state.params, train=False)[0]
879
- loss = loss_fn(logits, labels)
 
 
 
 
 
 
 
 
880
  return loss
881
 
882
  # Create parallel version of the train and eval step
@@ -946,7 +950,18 @@ def main():
946
  leave=False,
947
  total=eval_steps,
948
  ):
949
- # freeze batch to pass safely to JAX transforms
 
 
 
 
 
 
 
 
 
 
 
950
  batch = freeze(batch)
951
  # accumulate losses async
952
  eval_loss.append(p_eval_step(state, batch))
 
99
 
100
  def __post_init__(self):
101
  if self.restore_state:
102
+ assert self.model_name_or_path is not None and (
103
  "/model-" in self.model_name_or_path
104
  ), "Restoring state only available with W&B artifact reference"
105
  self.state_artifact = self.model_name_or_path.replace(
 
222
  )
223
 
224
  per_device_train_batch_size: int = field(
225
+ default=8,
226
+ metadata={"help": "Batch size per data parallel device for training."},
227
  )
228
  per_device_eval_batch_size: Optional[int] = field(
229
  default=None,
230
  metadata={
231
+ "help": "Batch size per data parallel device for evaluation. Same as training batch size if not set."
232
  },
233
  )
234
 
 
524
  # Preprocessing the datasets.
525
  # We need to normalize and tokenize inputs and targets.
526
 
527
+ dataset.preprocess(tokenizer=tokenizer, config=model.config)
 
 
 
 
 
528
 
529
  # Initialize our training
530
  rng = jax.random.PRNGKey(training_args.seed_model)
 
870
 
871
  # Define eval fn
872
  def eval_step(state, batch):
873
+ def compute_eval_loss(batch):
874
+ batch, labels = batch.pop("labels")
875
+ logits = state.apply_fn(**batch, params=state.params, train=False)[0]
876
+ return loss_fn(logits, labels)
877
+
878
+ # calculate loss independently per dp_device
879
+ loss = jax.vmap(compute_eval_loss, in_axes=(0,), out_axes=0)(batch)
880
+ # ensure they are sharded over dp devices
881
+ loss = with_sharding_constraint(loss, PartitionSpec("batch"))
882
+ # average across all devices
883
+ loss = jnp.mean(loss)
884
  return loss
885
 
886
  # Create parallel version of the train and eval step
 
950
  leave=False,
951
  total=eval_steps,
952
  ):
953
+ # reshape data into (dp_devices, batch_per_dp, ...)
954
+ batch = jax.tree_map(
955
+ lambda x: x.reshape(
956
+ (
957
+ training_args.dp_devices,
958
+ training_args.per_device_eval_batch_size,
959
+ )
960
+ + x.shape[1:]
961
+ ),
962
+ batch,
963
+ )
964
+ # freeze batch to pass safely to jax transforms
965
  batch = freeze(batch)
966
  # accumulate losses async
967
  eval_loss.append(p_eval_step(state, batch))