boris committed
Commit 803c7df (parent: 6aa30f5)

feat: use model definition

Files changed (2):
  1. dalle_mini/model.py +23 -22
  2. dev/seq2seq/run_seq2seq_flax.py +56 -132
dalle_mini/model.py CHANGED
@@ -1,4 +1,3 @@
-
 import jax
 import flax.linen as nn
 
@@ -7,25 +6,14 @@ from transformers.models.bart.modeling_flax_bart import (
     FlaxBartForConditionalGenerationModule,
     FlaxBartForConditionalGeneration,
     FlaxBartEncoder,
-    FlaxBartDecoder
+    FlaxBartDecoder,
 )
 
 from transformers import BartConfig
 
 
-# Model hyperparameters, for convenience
-OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
-OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
-BOS_TOKEN_ID = 16384
-BASE_MODEL = 'facebook/bart-large-cnn'  # we currently have issues with bart-large
-
-
 class CustomFlaxBartModule(FlaxBartModule):
     def setup(self):
-        # check config is valid, otherwise set default values
-        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
-        self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
-
         # we keep shared to easily load pre-trained weights
         self.shared = nn.Embed(
             self.config.vocab_size,
@@ -35,32 +23,45 @@ class CustomFlaxBartModule(FlaxBartModule):
         )
         # a separate embedding is used for the decoder
         self.decoder_embed = nn.Embed(
-            self.config.vocab_size_output,
+            self.config.image_vocab_size + 1,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
             dtype=self.dtype,
         )
-        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+        self.encoder = FlaxBartEncoder(
+            self.config, dtype=self.dtype, embed_tokens=self.shared
+        )
 
         # the decoder has a different config
         decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
-        decoder_config.vocab_size = self.config.vocab_size_output
-        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
+        decoder_config.max_position_embeddings = (
+            self.config.image_length + 1  # image tokens + BOS
+        )
+        decoder_config.vocab_size = self.config.image_vocab_size + 1
+        self.decoder = FlaxBartDecoder(
+            decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
+        )
+
 
-class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
+class CustomFlaxBartForConditionalGenerationModule(
+    FlaxBartForConditionalGenerationModule
+):
     def setup(self):
         # check config is valid, otherwise set default values
-        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+        # TODO: simplify with custom config class
+        self.config.text_normalized = True / False
 
         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
-            self.config.vocab_size_output,
+            self.config.image_vocab_size + 1,  # encoded image token space + 1 for bos
             use_bias=False,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
         )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
+        self.final_logits_bias = self.param(
+            "final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
+        )
+

 class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
     module_class = CustomFlaxBartForConditionalGenerationModule
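
Below is a minimal sketch (not part of the commit) of the decoder-config arithmetic in setup() above. The sizes 16384 and 256 are illustrative assumptions matching the constants this commit removes, BartConfig() stands in for a pretrained text config, and the dict is **-unpacked here for clarity where the commit passes it positionally:

    from transformers import BartConfig

    config = BartConfig()            # stands in for the pretrained text config
    config.image_vocab_size = 16384  # assumed image codebook size (was OUTPUT_VOCAB_SIZE - 1)
    config.image_length = 256        # assumed tokens per image (was OUTPUT_LENGTH - 1)

    # clone the text config, then resize vocab/positions for the image-token decoder
    decoder_config = BartConfig(**config.to_dict())
    decoder_config.vocab_size = config.image_vocab_size + 1           # image codes + BOS
    decoder_config.max_position_embeddings = config.image_length + 1  # image tokens + BOS

    assert decoder_config.vocab_size == 16385
    assert decoder_config.max_position_embeddings == 257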
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -17,10 +17,9 @@
 Fine-tuning the library models for seq2seq, text to image.
 Script adapted from run_summarization_flax.py
 """
-# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
 import os
-import logging as pylogging  # To avoid collision with transformers.utils.logging
+import logging
 import sys
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -44,7 +43,6 @@ from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
     AutoTokenizer,
-    FlaxBartForConditionalGeneration,
     HfArgumentParser,
     TrainingArguments,
 )
@@ -53,16 +51,9 @@ from transformers.models.bart.modeling_flax_bart import *
 import wandb
 
 from dalle_mini.text import TextNormalizer
+from dalle_mini.model import CustomFlaxBartForConditionalGeneration
 
-logger = pylogging.getLogger(__name__)
-
-
-# Model hyperparameters, for convenience
-# TODO: the model has now it's own definition file and should be imported
-OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
-OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
-BOS_TOKEN_ID = 16384
-BASE_MODEL = "facebook/bart-large-cnn"  # we currently have issues with bart-large
+logger = logging.getLogger(__name__)
 
 
 @dataclass
@@ -72,24 +63,30 @@ class ModelArguments:
     """
 
     model_name_or_path: Optional[str] = field(
-        default=BASE_MODEL,
+        default=None,
         metadata={
             "help": "The model checkpoint for weights initialization."
            "Don't set if you want to train a model from scratch."
         },
     )
-    config_name: Optional[str] = field(
+    image_vocab_size: Optional[int] = field(
         default=None,
-        metadata={
-            "help": "Pretrained config name or path if not the same as model_name"
-        },
+        metadata={"help": "Vocab size of image encoder"},
+    )
+    image_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of tokens per image"},
     )
-    use_fast_tokenizer: bool = field(
-        default=True,
+    tokenizer_name: Optional[str] = field(
+        default=None,
         metadata={
-            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
+            "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
         },
     )
+    normalize_text: bool = field(
+        default=False,
+        metadata={"help": "Whether to normalize text or not."},
+    )
     dtype: Optional[str] = field(
         default="float32",
         metadata={
@@ -158,22 +155,6 @@ class DataTrainingArguments:
         default=False,
         metadata={"help": "Whether to use decay in the learning rate scheduler."},
     )
-    max_target_length: Optional[int] = field(
-        default=OUTPUT_LENGTH,
-        metadata={
-            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
-            "than this will be truncated, sequences shorter will be padded."
-        },
-    )
-    val_max_target_length: Optional[int] = field(
-        default=OUTPUT_LENGTH,
-        metadata={
-            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
-            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
-            "This argument is also used to override the `max_length` param of `model.generate`, which is used "
-            "during evaluation."
-        },
-    )
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -188,10 +169,6 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    normalize_text: bool = field(
-        default=False,
-        metadata={"help": "Normalize/Simplify text"},
-    )
     preprocessing_num_workers: Optional[int] = field(
         default=80,  # ensure we have the same datasets cached data and avoid using too much space
         metadata={"help": "The number of processes to use for the preprocessing."},
@@ -243,8 +220,6 @@ class DataTrainingArguments:
             "json",
             "jsonl",
         ], "`validation_file` should be a tsv, csv or json file."
-        if self.val_max_target_length is None:
-            self.val_max_target_length = self.max_target_length
         if self.streaming and (self.len_train is None or self.len_eval is None):
             raise ValueError(
                 "Streaming requires providing length of training and validation datasets"
@@ -273,70 +248,6 @@ class TrainState(train_state.TrainState):
         return self.replace(step=new_step, opt_state=new_opt_state)
 
 
-class CustomFlaxBartModule(FlaxBartModule):
-    def setup(self):
-        # check config is valid, otherwise set default values
-        self.config.vocab_size_output = getattr(
-            self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
-        )
-        self.config.max_position_embeddings_decoder = getattr(
-            self.config, "max_position_embeddings_decoder", OUTPUT_LENGTH
-        )
-
-        # we keep shared to easily load pre-trained weights
-        self.shared = nn.Embed(
-            self.config.vocab_size,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        # a separate embedding is used for the decoder
-        self.decoder_embed = nn.Embed(
-            self.config.vocab_size_output,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        self.encoder = FlaxBartEncoder(
-            self.config, dtype=self.dtype, embed_tokens=self.shared
-        )
-
-        # the decoder has a different config
-        decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = (
-            self.config.max_position_embeddings_decoder
-        )
-        decoder_config.vocab_size = self.config.vocab_size_output
-        self.decoder = FlaxBartDecoder(
-            decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
-        )
-
-
-class CustomFlaxBartForConditionalGenerationModule(
-    FlaxBartForConditionalGenerationModule
-):
-    def setup(self):
-        # check config is valid, otherwise set default values
-        self.config.vocab_size_output = getattr(
-            self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
-        )
-
-        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
-        self.lm_head = nn.Dense(
-            self.config.vocab_size_output,
-            use_bias=False,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-        )
-        self.final_logits_bias = self.param(
-            "final_logits_bias", self.bias_init, (1, self.config.vocab_size_output)
-        )
-
-
-class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
-    module_class = CustomFlaxBartForConditionalGenerationModule
-
-
 def data_loader(
     rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
 ):
@@ -440,13 +351,13 @@ def main():
     )
 
     # Make one log on every process with the configuration for debugging.
-    pylogging.basicConfig(
+    logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
-        level=pylogging.INFO,
+        level=logging.INFO,
     )
     # Setup logging, we only want one process per machine to log things on the screen.
-    logger.setLevel(pylogging.INFO if jax.process_index() == 0 else pylogging.ERROR)
+    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
     if jax.process_index() == 0:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
@@ -483,44 +394,57 @@ def main():
     if model_args.from_checkpoint is not None:
         artifact = wandb.run.use_artifact(model_args.from_checkpoint)
         artifact_dir = artifact.download()
+
+        # load model
         model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
 
         # load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             artifact_dir,
-            use_fast=model_args.use_fast_tokenizer,
+            use_fast=True,
         )
 
     else:
         # Set up our new model config
+        # TODO: simplify with custom config class
         config = BartConfig.from_pretrained(model_args.model_name_or_path)
-        config.tie_word_embeddings = False
-        config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
-        config.bos_token_id = (
-            BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
-        )
-        config.pos_token_id = (
-            BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
-        )
-        config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
+        config.image_vocab_size = model_args.image_vocab_size
+        config.image_length = model_args.image_length
+        # we append decoder bos to image vocab
+        config.decoder_start_token_id = config.image_vocab_size
+        # ensure we don't generate bos (in addition to decoder start token)
+        config.force_bos_token_to_be_generated = False
         config.forced_bos_token_id = None  # we don't need this token
         config.forced_eos_token_id = None  # we don't need this token
-        config.force_bos_token_to_be_generated = (
-            False  # otherwise it sets bos_token_id at loading
-        )
-        config.min_length = data_args.max_target_length
-        config.max_length = data_args.max_target_length
+
+        config.tie_word_embeddings = False
+        config.min_length = model_args.image_length + 1
+        config.max_length = model_args.image_length + 1
+
+        # below tokens need to be set to avoid error during generation (converted to jnp.array)
+        # they are not expected to be used and are set to unreachable token id
+        config.bos_token_id = config.image_vocab_size + 1
+        config.pos_token_id = config.image_vocab_size + 1
+        config.eos_token_id = config.image_vocab_size + 1
+
+        # save whether we normalize the text
+        config.normalize_text = model_args.normalize_text
 
         # Create a custom model and initialize it randomly
-        model = CustomFlaxBartForConditionalGeneration(
+        model = CustomFlaxBartForConditionalGeneration.from_config(
            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
 
         # Load tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path,
-            use_fast=model_args.use_fast_tokenizer,
-        )
+        if model_args.tokenizer_name is not None:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_args.tokenizer_name, use_fast=True
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_args.model_name_or_path,
+                use_fast=True,
+            )
 
     print(f"TPUs: {jax.device_count()}")
     assert jax.device_count() == 8, "TPUs in use, please check running processes"
@@ -543,7 +467,7 @@ def main():
         shifted_input_ids[:, 0] = decoder_start_token_id
         return shifted_input_ids
 
-    text_normalizer = TextNormalizer() if data_args.normalize_text else None
+    text_normalizer = TextNormalizer() if model.config.normalize_text else None
 
     def normalize_text(example):
         example[text_column] = text_normalizer(example[text_column])
@@ -590,7 +514,7 @@ def main():
     )
     if data_args.streaming:
         train_dataset = train_dataset.shuffle(1000, training_args.seed)
-    if data_args.normalize_text:
+    if model.config.normalize_text:
         train_dataset = (
             train_dataset.map(normalize_text)
             if data_args.streaming
@@ -627,7 +551,7 @@ def main():
         if data_args.streaming
         else eval_dataset.select(range(data_args.max_train_samples))
     )
-    if data_args.normalize_text:
+    if model.config.normalize_text:
         eval_dataset = (
             eval_dataset.map(normalize_text)
             if data_args.streaming
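
To make the generation settings above concrete, here is a small worked example (not from the commit) of the token-id layout; 16384 matches the removed BOS_TOKEN_ID constant and is only an assumption, since the real value now arrives via the image_vocab_size argument:

    image_vocab_size = 16384  # image codes occupy ids 0 .. 16383
    image_length = 256

    decoder_start_token_id = image_vocab_size  # 16384: the BOS appended to the image vocab
    unreachable_id = image_vocab_size + 1      # 16385: assigned to bos/pos/eos above

    # lm_head and decoder embedding have image_vocab_size + 1 entries (ids 0 .. 16384),
    # so id 16385 can never be produced and decoding always runs to max_length
    num_output_ids = image_vocab_size + 1
    max_length = image_length + 1              # 257 = BOS + 256 image tokens
    assert decoder_start_token_id < num_output_ids
    assert unreachable_id >= num_output_ids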