boris commited on
Commit
2816f98
·
unverified ·
2 Parent(s): 6aa30f5 c55ecf8

Merge pull request #107 from borisdayma/feat-seq2seq

Browse files
dalle_mini/model.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import jax
3
  import flax.linen as nn
4
 
@@ -7,60 +6,56 @@ from transformers.models.bart.modeling_flax_bart import (
7
  FlaxBartForConditionalGenerationModule,
8
  FlaxBartForConditionalGeneration,
9
  FlaxBartEncoder,
10
- FlaxBartDecoder
11
  )
12
 
13
  from transformers import BartConfig
14
 
15
 
16
- # Model hyperparameters, for convenience
17
- OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
18
- OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
19
- BOS_TOKEN_ID = 16384
20
- BASE_MODEL = 'facebook/bart-large-cnn' # we currently have issues with bart-large
21
-
22
-
23
  class CustomFlaxBartModule(FlaxBartModule):
24
  def setup(self):
25
- # check config is valid, otherwise set default values
26
- self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
27
- self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
28
-
29
  # we keep shared to easily load pre-trained weights
30
  self.shared = nn.Embed(
31
  self.config.vocab_size,
32
  self.config.d_model,
33
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
34
- dtype=self.dtype,
35
  )
36
  # a separate embedding is used for the decoder
37
  self.decoder_embed = nn.Embed(
38
- self.config.vocab_size_output,
39
  self.config.d_model,
40
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
41
- dtype=self.dtype,
 
 
42
  )
43
- self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
44
 
45
  # the decoder has a different config
 
46
  decoder_config = BartConfig(self.config.to_dict())
47
- decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
48
- decoder_config.vocab_size = self.config.vocab_size_output
49
- self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
 
 
 
 
50
 
51
- class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
52
- def setup(self):
53
- # check config is valid, otherwise set default values
54
- self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
55
 
 
 
 
 
56
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
57
  self.lm_head = nn.Dense(
58
- self.config.vocab_size_output,
59
  use_bias=False,
60
- dtype=self.dtype,
61
- kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
62
  )
63
- self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
 
 
 
64
 
65
  class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
66
  module_class = CustomFlaxBartForConditionalGenerationModule
 
 
1
  import jax
2
  import flax.linen as nn
3
 
 
6
  FlaxBartForConditionalGenerationModule,
7
  FlaxBartForConditionalGeneration,
8
  FlaxBartEncoder,
9
+ FlaxBartDecoder,
10
  )
11
 
12
  from transformers import BartConfig
13
 
14
 
 
 
 
 
 
 
 
15
  class CustomFlaxBartModule(FlaxBartModule):
16
  def setup(self):
 
 
 
 
17
  # we keep shared to easily load pre-trained weights
18
  self.shared = nn.Embed(
19
  self.config.vocab_size,
20
  self.config.d_model,
21
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
22
  )
23
  # a separate embedding is used for the decoder
24
  self.decoder_embed = nn.Embed(
25
+ self.config.image_vocab_size + 1,
26
  self.config.d_model,
27
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
28
+ )
29
+ self.encoder = FlaxBartEncoder(
30
+ self.config, dtype=self.dtype, embed_tokens=self.shared
31
  )
 
32
 
33
  # the decoder has a different config
34
+ # TODO: should not be needed once we have custom config/module
35
  decoder_config = BartConfig(self.config.to_dict())
36
+ decoder_config.max_position_embeddings = (
37
+ self.config.image_length + 1 # image tokens + BOS
38
+ )
39
+ decoder_config.vocab_size = self.config.image_vocab_size + 1
40
+ self.decoder = FlaxBartDecoder(
41
+ decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
42
+ )
43
 
 
 
 
 
44
 
45
+ class CustomFlaxBartForConditionalGenerationModule(
46
+ FlaxBartForConditionalGenerationModule
47
+ ):
48
+ def setup(self):
49
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
50
  self.lm_head = nn.Dense(
51
+ self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
52
  use_bias=False,
53
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
 
54
  )
55
+ self.final_logits_bias = self.param(
56
+ "final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
57
+ )
58
+
59
 
60
  class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
61
  module_class = CustomFlaxBartForConditionalGenerationModule
dev/inference/samples.txt CHANGED
@@ -24,7 +24,6 @@ underwater cathedral
24
  a photo of a fantasy version of New York City
25
  a picture of fantasy kingdoms
26
  a volcano erupting next to San Francisco golden gate bridge
27
- big wave destroying a city
28
  Paris in a far future, futuristic Paris
29
  real painting of an alien from Monet
30
  the communist statue of liberty
@@ -54,16 +53,16 @@ a long line of green blocks on a beach at subset
54
  a long line of peaches on a beach at sunset
55
  a picture of a castle from minecraft
56
  a cute pikachu teapot
57
- an illustration of pikachu sitting on a bench
58
- mario is jumping over a zebra during the sunset
59
  famous anime hero
60
  star wars concept art
61
  a cartoon of a superhero bear
62
  an illustration of a cute skeleton wearing a blue hoodie
63
  illustration of a baby shark swimming around corals
 
64
  Cartoon of a carrot with big eyes
65
  logo of a robot wearing glasses and reading a book
66
- a cactus lifting weights
67
  illustration of a cactus lifting weigths
68
  logo of a cactus lifting weights
69
  a photo of a camera from the future
@@ -72,7 +71,6 @@ a collection of glasses is sitting on a table
72
  a painting of a capybara sitting on a mountain during fall in surrealist style
73
  a pentagonal green clock
74
  a pixel art illustration of an eagle sitting in a field in the afternoon
75
- a professional high-quality emoji of a lovestruck cup of boba
76
  a small red block sitting on a large green block
77
  a storefront that has the word 'openai' written on it
78
  a tatoo of a black broccoli
@@ -88,10 +86,7 @@ urinals are lined up in a jungle
88
  a muscular banana sitting upright on a bench smoking watching a banana on television, high definition photography
89
  a human face
90
  a person is holding a phone and a waterbottle, running a marathon
91
- a photograph of Ellen G. White
92
  Young woman riding her bike through the forest
93
- a portrait of a nightmare creature watching at you
94
- a white room full of a black substance
95
  the best soccer team of the world
96
  the best basketball team of the world
97
  the best football team of the world
@@ -100,6 +95,7 @@ sad, sadness
100
  the representation of infinity
101
  the end of the world
102
  the last sunrise on earth
 
103
  an avocado armchair
104
  an armchair in the shape of an avocado
105
  illustration of an avocado armchair
@@ -109,4 +105,3 @@ an avocado armchair flying into space
109
  a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps
110
  an illustration of an avocado in a christmas sweater staring at its reflection in a mirror
111
  illustration of an avocado armchair getting married to a pineapple
112
- an illustration of an avocado in a beanie riding a motorcycle
 
24
  a photo of a fantasy version of New York City
25
  a picture of fantasy kingdoms
26
  a volcano erupting next to San Francisco golden gate bridge
 
27
  Paris in a far future, futuristic Paris
28
  real painting of an alien from Monet
29
  the communist statue of liberty
 
53
  a long line of peaches on a beach at sunset
54
  a picture of a castle from minecraft
55
  a cute pikachu teapot
56
+ an illustration of pikachu sitting on a bench eating an ice cream
57
+ mario is jumping over a zebra
58
  famous anime hero
59
  star wars concept art
60
  a cartoon of a superhero bear
61
  an illustration of a cute skeleton wearing a blue hoodie
62
  illustration of a baby shark swimming around corals
63
+ an illustration of an avocado in a beanie riding a motorcycle
64
  Cartoon of a carrot with big eyes
65
  logo of a robot wearing glasses and reading a book
 
66
  illustration of a cactus lifting weigths
67
  logo of a cactus lifting weights
68
  a photo of a camera from the future
 
71
  a painting of a capybara sitting on a mountain during fall in surrealist style
72
  a pentagonal green clock
73
  a pixel art illustration of an eagle sitting in a field in the afternoon
 
74
  a small red block sitting on a large green block
75
  a storefront that has the word 'openai' written on it
76
  a tatoo of a black broccoli
 
86
  a muscular banana sitting upright on a bench smoking watching a banana on television, high definition photography
87
  a human face
88
  a person is holding a phone and a waterbottle, running a marathon
 
89
  Young woman riding her bike through the forest
 
 
90
  the best soccer team of the world
91
  the best basketball team of the world
92
  the best football team of the world
 
95
  the representation of infinity
96
  the end of the world
97
  the last sunrise on earth
98
+ a portrait of a nightmare creature watching at you
99
  an avocado armchair
100
  an armchair in the shape of an avocado
101
  illustration of an avocado armchair
 
105
  a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps
106
  an illustration of an avocado in a christmas sweater staring at its reflection in a mirror
107
  illustration of an avocado armchair getting married to a pineapple
 
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -17,11 +17,11 @@
17
  Fine-tuning the library models for seq2seq, text to image.
18
  Script adapted from run_summarization_flax.py
19
  """
20
- # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
21
 
22
  import os
23
- import logging as pylogging # To avoid collision with transformers.utils.logging
24
  import sys
 
25
  from dataclasses import dataclass, field
26
  from pathlib import Path
27
  from typing import Callable, Optional
@@ -38,31 +38,21 @@ import optax
38
  import transformers
39
  from flax import jax_utils, traverse_util
40
  from flax.serialization import from_bytes, to_bytes
41
- import flax.linen as nn
42
  from flax.jax_utils import unreplicate
43
  from flax.training import train_state
44
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
45
  from transformers import (
46
  AutoTokenizer,
47
- FlaxBartForConditionalGeneration,
48
  HfArgumentParser,
49
- TrainingArguments,
50
  )
51
- from transformers.models.bart.modeling_flax_bart import *
52
 
53
  import wandb
54
 
55
  from dalle_mini.text import TextNormalizer
 
56
 
57
- logger = pylogging.getLogger(__name__)
58
-
59
-
60
- # Model hyperparameters, for convenience
61
- # TODO: the model has now it's own definition file and should be imported
62
- OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
63
- OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
64
- BOS_TOKEN_ID = 16384
65
- BASE_MODEL = "facebook/bart-large-cnn" # we currently have issues with bart-large
66
 
67
 
68
  @dataclass
@@ -72,36 +62,36 @@ class ModelArguments:
72
  """
73
 
74
  model_name_or_path: Optional[str] = field(
75
- default=BASE_MODEL,
76
  metadata={
77
  "help": "The model checkpoint for weights initialization."
78
  "Don't set if you want to train a model from scratch."
79
  },
80
  )
81
- config_name: Optional[str] = field(
82
  default=None,
83
- metadata={
84
- "help": "Pretrained config name or path if not the same as model_name"
85
- },
86
  )
87
- use_fast_tokenizer: bool = field(
88
- default=True,
 
 
 
 
89
  metadata={
90
- "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
91
  },
92
  )
 
 
 
 
93
  dtype: Optional[str] = field(
94
  default="float32",
95
  metadata={
96
  "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
97
  },
98
  )
99
- from_checkpoint: Optional[str] = field(
100
- default=None,
101
- metadata={
102
- "help": "Loads a pretrained wandb checkpoint. Use artifact reference."
103
- },
104
- )
105
 
106
 
107
  @dataclass
@@ -139,13 +129,11 @@ class DataTrainingArguments:
139
  default=False,
140
  metadata={"help": "Whether to stream the dataset."},
141
  )
142
- len_train: Optional[int] = field(
143
- default=None,
144
- metadata={"help": "Length of training dataset, required for streaming"},
145
- )
146
- len_eval: Optional[int] = field(
147
- default=None,
148
- metadata={"help": "Length of validation dataset, required for streaming"},
149
  )
150
  max_source_length: Optional[int] = field(
151
  default=128,
@@ -154,26 +142,6 @@ class DataTrainingArguments:
154
  "than this will be truncated, sequences shorter will be padded."
155
  },
156
  )
157
- no_decay: bool = field(
158
- default=False,
159
- metadata={"help": "Whether to use decay in the learning rate scheduler."},
160
- )
161
- max_target_length: Optional[int] = field(
162
- default=OUTPUT_LENGTH,
163
- metadata={
164
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
165
- "than this will be truncated, sequences shorter will be padded."
166
- },
167
- )
168
- val_max_target_length: Optional[int] = field(
169
- default=OUTPUT_LENGTH,
170
- metadata={
171
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
172
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
173
- "This argument is also used to override the `max_length` param of `model.generate`, which is used "
174
- "during evaluation."
175
- },
176
- )
177
  max_train_samples: Optional[int] = field(
178
  default=None,
179
  metadata={
@@ -188,71 +156,144 @@ class DataTrainingArguments:
188
  "value if set."
189
  },
190
  )
191
- normalize_text: bool = field(
192
- default=False,
193
- metadata={"help": "Normalize/Simplify text"},
194
- )
195
  preprocessing_num_workers: Optional[int] = field(
196
- default=80, # ensure we have the same datasets cached data and avoid using too much space
197
- metadata={"help": "The number of processes to use for the preprocessing."},
198
- )
199
- source_prefix: Optional[str] = field(
200
  default=None,
201
  metadata={
202
- "help": "A prefix to add before every source text (useful for T5 models)."
203
  },
204
  )
205
  overwrite_cache: bool = field(
206
  default=False,
207
- metadata={"help": "Overwrite the cached training and evaluation sets"},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  )
209
- log_interval: Optional[int] = field(
210
- default=40,
211
- metadata={"help": "Log frequency for metrics"},
212
  )
213
  log_model: bool = field(
214
  default=False,
215
- metadata={"help": "Overwrite the cached training and evaluation sets"},
216
  )
217
- save_model_steps: Optional[int] = field(
218
- default=5000, # about once every 1.5h in our experiments
 
 
 
 
 
 
 
 
219
  metadata={
220
- "help": "For logging the model more frequently. Used only when `log_model` is set."
221
  },
222
  )
223
 
224
- def __post_init__(self):
225
- if self.dataset_repo_or_path is None:
226
- raise ValueError("Need a dataset repository or path.")
227
- if self.train_file is None or self.validation_file is None:
228
- raise ValueError("Need training/validation file.")
229
- else:
230
- if self.train_file is not None:
231
- extension = self.train_file.split(".")[-1]
232
- assert extension in [
233
- "tsv",
234
- "csv",
235
- "json",
236
- "jsonl",
237
- ], "`train_file` should be a tsv, csv or json file."
238
- if self.validation_file is not None:
239
- extension = self.validation_file.split(".")[-1]
240
- assert extension in [
241
- "tsv",
242
- "csv",
243
- "json",
244
- "jsonl",
245
- ], "`validation_file` should be a tsv, csv or json file."
246
- if self.val_max_target_length is None:
247
- self.val_max_target_length = self.max_target_length
248
- if self.streaming and (self.len_train is None or self.len_eval is None):
249
- raise ValueError(
250
- "Streaming requires providing length of training and validation datasets"
251
- )
252
 
253
 
254
  class TrainState(train_state.TrainState):
255
  dropout_rng: jnp.ndarray = None
 
 
 
256
 
257
  def replicate(self):
258
  return jax_utils.replicate(self).replace(
@@ -264,81 +305,23 @@ class TrainState(train_state.TrainState):
264
  with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
265
  new_opt_state = from_bytes(self.opt_state, f.read())
266
 
267
- # restore steps
268
  with (Path(artifact_dir) / "training_state.json").open("r") as f:
269
  training_state = json.load(f)
270
- new_step = training_state["step"]
271
 
272
  # replace state
273
- return self.replace(step=new_step, opt_state=new_opt_state)
274
-
275
-
276
- class CustomFlaxBartModule(FlaxBartModule):
277
- def setup(self):
278
- # check config is valid, otherwise set default values
279
- self.config.vocab_size_output = getattr(
280
- self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
281
- )
282
- self.config.max_position_embeddings_decoder = getattr(
283
- self.config, "max_position_embeddings_decoder", OUTPUT_LENGTH
284
  )
285
 
286
- # we keep shared to easily load pre-trained weights
287
- self.shared = nn.Embed(
288
- self.config.vocab_size,
289
- self.config.d_model,
290
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
291
- dtype=self.dtype,
292
- )
293
- # a separate embedding is used for the decoder
294
- self.decoder_embed = nn.Embed(
295
- self.config.vocab_size_output,
296
- self.config.d_model,
297
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
298
- dtype=self.dtype,
299
- )
300
- self.encoder = FlaxBartEncoder(
301
- self.config, dtype=self.dtype, embed_tokens=self.shared
302
- )
303
-
304
- # the decoder has a different config
305
- decoder_config = BartConfig(self.config.to_dict())
306
- decoder_config.max_position_embeddings = (
307
- self.config.max_position_embeddings_decoder
308
- )
309
- decoder_config.vocab_size = self.config.vocab_size_output
310
- self.decoder = FlaxBartDecoder(
311
- decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
312
- )
313
-
314
-
315
- class CustomFlaxBartForConditionalGenerationModule(
316
- FlaxBartForConditionalGenerationModule
317
- ):
318
- def setup(self):
319
- # check config is valid, otherwise set default values
320
- self.config.vocab_size_output = getattr(
321
- self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
322
- )
323
-
324
- self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
325
- self.lm_head = nn.Dense(
326
- self.config.vocab_size_output,
327
- use_bias=False,
328
- dtype=self.dtype,
329
- kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
330
- )
331
- self.final_logits_bias = self.param(
332
- "final_logits_bias", self.bias_init, (1, self.config.vocab_size_output)
333
- )
334
-
335
-
336
- class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
337
- module_class = CustomFlaxBartForConditionalGenerationModule
338
-
339
 
340
  def data_loader(
341
- rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
 
 
342
  ):
343
  """
344
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
@@ -346,7 +329,7 @@ def data_loader(
346
  """
347
  steps_per_epoch = len(dataset) // batch_size
348
 
349
- if shuffle:
350
  batch_idx = jax.random.permutation(rng, len(dataset))
351
  else:
352
  batch_idx = jnp.arange(len(dataset))
@@ -375,20 +358,20 @@ def data_loader_streaming(dataset: Dataset, batch_size: int):
375
 
376
 
377
  def create_learning_rate_fn(
378
- train_ds_size: int,
379
- train_batch_size: int,
380
- num_train_epochs: int,
381
  num_warmup_steps: int,
382
  learning_rate: float,
383
- no_decay: bool,
 
384
  ) -> Callable[[int], jnp.array]:
385
  """Returns a linear warmup, linear_decay learning rate function."""
386
- steps_per_epoch = train_ds_size // train_batch_size
387
- num_train_steps = steps_per_epoch * num_train_epochs
 
 
388
  warmup_fn = optax.linear_schedule(
389
  init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
390
  )
391
- if no_decay:
392
  return warmup_fn
393
  decay_fn = optax.linear_schedule(
394
  init_value=learning_rate,
@@ -412,10 +395,7 @@ def wandb_log(metrics, step=None, prefix=None):
412
 
413
 
414
  def main():
415
- # See all possible arguments in src/transformers/training_args.py
416
- # or by passing the --help flag to this script.
417
- # We now keep distinct sets of args, for a cleaner separation of concerns.
418
-
419
  parser = HfArgumentParser(
420
  (ModelArguments, DataTrainingArguments, TrainingArguments)
421
  )
@@ -440,13 +420,13 @@ def main():
440
  )
441
 
442
  # Make one log on every process with the configuration for debugging.
443
- pylogging.basicConfig(
444
  format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
445
  datefmt="%m/%d/%Y %H:%M:%S",
446
- level=pylogging.INFO,
447
  )
448
  # Setup logging, we only want one process per machine to log things on the screen.
449
- logger.setLevel(pylogging.INFO if jax.process_index() == 0 else pylogging.ERROR)
450
  if jax.process_index() == 0:
451
  datasets.utils.logging.set_verbosity_warning()
452
  transformers.utils.logging.set_verbosity_info()
@@ -457,18 +437,19 @@ def main():
457
  # Set the verbosity to info of the Transformers logger (on main process only):
458
  logger.info(f"Training/evaluation parameters {training_args}")
459
 
460
- # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
461
- # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
462
- # (the dataset will be downloaded automatically from the datasets Hub).
463
- #
464
- data_files = {
465
- "train": data_args.train_file,
466
- "validation": data_args.validation_file,
467
- }
468
  dataset = load_dataset(
469
  data_args.dataset_repo_or_path,
470
  data_files=data_files,
471
  streaming=data_args.streaming,
 
472
  )
473
 
474
  # Set up wandb run
@@ -477,56 +458,66 @@ def main():
477
  project="dalle-mini",
478
  job_type="Seq2Seq",
479
  config=parser.parse_args(),
480
- save_code=True,
481
  )
482
 
483
- if model_args.from_checkpoint is not None:
484
- artifact = wandb.run.use_artifact(model_args.from_checkpoint)
485
  artifact_dir = artifact.download()
 
 
486
  model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
487
 
488
  # load tokenizer
489
  tokenizer = AutoTokenizer.from_pretrained(
490
  artifact_dir,
491
- use_fast=model_args.use_fast_tokenizer,
492
  )
493
 
494
  else:
495
  # Set up our new model config
 
496
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
497
- config.tie_word_embeddings = False
498
- config.decoder_start_token_id = BOS_TOKEN_ID # for first token
499
- config.bos_token_id = (
500
- BOS_TOKEN_ID # should not be used (due to forced_bos_token_id)
501
- )
502
- config.pos_token_id = (
503
- BOS_TOKEN_ID # should not be needed (as we generate until max_length)
504
- )
505
- config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
506
  config.forced_bos_token_id = None # we don't need this token
507
  config.forced_eos_token_id = None # we don't need this token
508
- config.force_bos_token_to_be_generated = (
509
- False # otherwise it sets bos_token_id at loading
510
- )
511
- config.min_length = data_args.max_target_length
512
- config.max_length = data_args.max_target_length
 
 
 
 
 
 
 
 
513
 
514
  # Create a custom model and initialize it randomly
515
  model = CustomFlaxBartForConditionalGeneration(
516
- config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
517
  )
518
 
519
  # Load tokenizer
520
- tokenizer = AutoTokenizer.from_pretrained(
521
- model_args.model_name_or_path,
522
- use_fast=model_args.use_fast_tokenizer,
523
- )
 
 
 
 
 
524
 
525
  print(f"TPUs: {jax.device_count()}")
526
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
527
 
528
- prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
529
-
530
  # Preprocessing the datasets.
531
  # We need to tokenize inputs and targets.
532
 
@@ -543,7 +534,7 @@ def main():
543
  shifted_input_ids[:, 0] = decoder_start_token_id
544
  return shifted_input_ids
545
 
546
- text_normalizer = TextNormalizer() if data_args.normalize_text else None
547
 
548
  def normalize_text(example):
549
  example[text_column] = text_normalizer(example[text_column])
@@ -551,7 +542,6 @@ def main():
551
 
552
  def preprocess_function(examples):
553
  inputs = examples[text_column]
554
- inputs = [prefix + inp for inp in inputs] if prefix else inputs
555
  # Setting padding="max_length" as we need fixed length inputs for jitted functions
556
  model_inputs = tokenizer(
557
  inputs,
@@ -589,8 +579,15 @@ def main():
589
  else train_dataset.select(range(data_args.max_train_samples))
590
  )
591
  if data_args.streaming:
592
- train_dataset = train_dataset.shuffle(1000, training_args.seed)
593
- if data_args.normalize_text:
 
 
 
 
 
 
 
594
  train_dataset = (
595
  train_dataset.map(normalize_text)
596
  if data_args.streaming
@@ -627,7 +624,7 @@ def main():
627
  if data_args.streaming
628
  else eval_dataset.select(range(data_args.max_train_samples))
629
  )
630
- if data_args.normalize_text:
631
  eval_dataset = (
632
  eval_dataset.map(normalize_text)
633
  if data_args.streaming
@@ -655,7 +652,7 @@ def main():
655
  )
656
 
657
  # Initialize our training
658
- rng = jax.random.PRNGKey(training_args.seed)
659
  rng, dropout_rng = jax.random.split(rng)
660
 
661
  # Store some constant
@@ -665,35 +662,29 @@ def main():
665
  )
666
  batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
667
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
 
668
  if data_args.streaming:
669
- len_train_dataset = data_args.len_train
670
- if (
671
- data_args.max_train_samples is not None
672
- and data_args.max_train_samples < len_train_dataset
673
- ):
674
  len_train_dataset = data_args.max_train_samples
675
-
676
- len_eval_dataset = data_args.len_eval
677
- if (
678
- data_args.max_eval_samples is not None
679
- and data_args.max_eval_samples < len_eval_dataset
680
- ):
681
  len_eval_dataset = data_args.max_eval_samples
682
  else:
683
  len_train_dataset = len(train_dataset)
684
  len_eval_dataset = len(eval_dataset)
685
- steps_per_epoch = len_train_dataset // train_batch_size
686
- total_steps = steps_per_epoch * num_epochs
687
- total_optimization_steps = (len_train_dataset // batch_size_per_update) * num_epochs
 
 
 
688
 
689
  # Create learning rate schedule
690
  learning_rate_fn = create_learning_rate_fn(
691
- len_train_dataset,
692
- train_batch_size,
693
- training_args.num_train_epochs,
694
  training_args.warmup_steps,
695
  training_args.learning_rate,
696
- data_args.no_decay,
 
697
  )
698
 
699
  # We use Optax's "masking" functionality to not apply weight decay
@@ -701,8 +692,6 @@ def main():
701
  # mask boolean with the same structure as the parameters.
702
  # The mask is True for parameters that should be decayed.
703
  # Note that this mask is specifically adapted for FlaxBart.
704
- # For FlaxT5, one should correct the layer norm parameter naming
705
- # accordingly - see `run_t5_mlm_flax.py` e.g.
706
  def decay_mask_fn(params):
707
  flat_params = traverse_util.flatten_dict(params)
708
  layer_norm_params = [
@@ -725,6 +714,9 @@ def main():
725
  # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
726
  optimizer = optax.adafactor(
727
  learning_rate=learning_rate_fn,
 
 
 
728
  )
729
  else:
730
  optimizer = optax.adamw(
@@ -749,11 +741,10 @@ def main():
749
  tx=optimizer,
750
  dropout_rng=dropout_rng,
751
  )
752
- if model_args.from_checkpoint is not None:
753
- # restore optimizer state and step
 
754
  state = state.restore_state(artifact_dir)
755
- # TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
756
- # TODO: optimizer may use a different step for learning rate, we should serialize/restore entire state
757
 
758
  # label smoothed cross entropy
759
  def loss_fn(logits, labels):
@@ -762,7 +753,7 @@ def main():
762
  return loss
763
 
764
  # Define gradient update step fn
765
- def train_step(state, batch):
766
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
767
 
768
  def compute_loss(params, batch):
@@ -776,14 +767,20 @@ def main():
776
  grad_fn = jax.value_and_grad(compute_loss)
777
  loss, grads = grad_fn(state.params, batch)
778
  grads = jax.lax.pmean(grads, "batch")
779
- state = state.apply_gradients(grads=grads)
 
 
 
 
 
780
 
781
  metrics = {
782
  "loss": loss,
783
  "learning_rate": learning_rate_fn(state.step),
784
  }
785
  metrics = jax.lax.pmean(metrics, axis_name="batch")
786
- return state.replace(dropout_rng=new_dropout_rng), metrics
 
787
 
788
  # Define eval fn
789
  def eval_step(params, batch):
@@ -800,10 +797,6 @@ def main():
800
  p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
801
  p_eval_step = jax.pmap(eval_step, "batch")
802
 
803
- # Replicate the train state on each device
804
- del model._params
805
- state = state.replicate()
806
-
807
  logger.info("***** Running training *****")
808
  logger.info(f" Num examples = {len_train_dataset}")
809
  logger.info(f" Num Epochs = {num_epochs}")
@@ -813,13 +806,12 @@ def main():
813
  logger.info(
814
  f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
815
  )
816
- logger.info(f" Total global steps = {total_steps}")
817
- logger.info(f" Total optimization steps = {total_optimization_steps}")
818
-
819
- epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
820
 
821
  # set default x-axis as 'train/step'
822
- wandb_log({}, step=unreplicate(state.step))
823
  wandb.define_metric("*", step_metric="train/step")
824
 
825
  # add interesting config parameters
@@ -828,11 +820,12 @@ def main():
828
  "len_train": len_train_dataset,
829
  "len_eval": len_eval_dataset,
830
  "batch_size_per_update": batch_size_per_update,
831
- "total_steps": total_steps,
832
- "total_optimization_steps": total_optimization_steps,
833
  }
834
  )
835
 
 
 
 
836
  def run_evaluation():
837
  # ======================== Evaluating ==============================
838
  eval_metrics = []
@@ -840,8 +833,12 @@ def main():
840
  if data_args.streaming:
841
  eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
842
  else:
843
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
844
- eval_steps = len_eval_dataset // eval_batch_size
 
 
 
 
845
  for batch in tqdm(
846
  eval_loader,
847
  desc="Evaluating...",
@@ -867,10 +864,9 @@ def main():
867
 
868
  return eval_metrics
869
 
870
- def run_save_model(state, step, epoch, eval_metrics=None):
871
  if jax.process_index() == 0:
872
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
873
-
874
  # save model locally
875
  model.save_pretrained(
876
  training_args.output_dir,
@@ -881,24 +877,30 @@ def main():
881
  tokenizer.save_pretrained(training_args.output_dir)
882
 
883
  # save state
884
- # TODO: maybe we should just save the full state object without params
885
- state = unreplicate(state)
886
  with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
887
- f.write(to_bytes(state.opt_state))
 
 
 
 
888
  with (Path(training_args.output_dir) / "training_state.json").open(
889
  "w"
890
  ) as f:
891
- json.dump({"step": state.step.item()}, f)
 
 
 
892
 
893
  # save to W&B
894
- if data_args.log_model:
895
  # save some space
896
  c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
897
- c.cleanup(wandb.util.from_human_size("5GB"))
898
 
899
- metadata = {"step": step, "epoch": epoch}
900
  if eval_metrics is not None:
901
- metadata["eval/loss"] = eval_metrics["loss"]
902
  artifact = wandb.Artifact(
903
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
904
  )
@@ -932,24 +934,26 @@ def main():
932
  training_args.output_dir,
933
  params=params,
934
  push_to_hub=training_args.push_to_hub,
935
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
936
  temp_dir=True, # avoid issues with being in a repository
937
  )
938
 
 
 
 
 
939
  for epoch in epochs:
 
940
  # ======================== Training ================================
941
- step = unreplicate(state.step)
942
- wandb_log({"train/epoch": epoch}, step=step)
943
 
944
  # Generate an epoch by shuffling sampling indices from the train dataset
945
  if data_args.streaming:
946
- train_dataset.set_epoch(epoch)
947
  train_loader = data_loader_streaming(train_dataset, train_batch_size)
948
  else:
949
- rng, input_rng = jax.random.split(rng)
950
- train_loader = data_loader(
951
- input_rng, train_dataset, train_batch_size, shuffle=True
952
- )
953
  # train
954
  for batch in tqdm(
955
  train_loader,
@@ -958,32 +962,49 @@ def main():
958
  leave=False,
959
  total=steps_per_epoch,
960
  ):
961
- state, train_metric = p_train_step(state, batch)
 
 
 
 
 
 
 
 
 
962
  step = unreplicate(state.step)
963
 
964
- if step % data_args.log_interval == 0 and jax.process_index() == 0:
965
  # log metrics
966
  wandb_log(unreplicate(train_metric), step=step, prefix="train")
967
-
 
 
 
 
 
 
 
968
  if training_args.eval_steps and step % training_args.eval_steps == 0:
969
- run_evaluation()
970
 
971
- if step % data_args.save_model_steps == 0:
972
- run_save_model(state, step, epoch)
973
 
974
  # log final train metrics
975
- wandb_log(unreplicate(train_metric), step=step, prefix="train")
 
 
976
 
977
- train_metric = unreplicate(train_metric)
978
- epochs.write(
979
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
980
- )
981
 
982
  # Final evaluation
983
  eval_metrics = run_evaluation()
984
 
985
  # save checkpoint after each epoch
986
- run_save_model(state, state.step, epoch, eval_metrics)
987
 
988
 
989
  if __name__ == "__main__":
 
17
  Fine-tuning the library models for seq2seq, text to image.
18
  Script adapted from run_summarization_flax.py
19
  """
 
20
 
21
  import os
22
+ import logging
23
  import sys
24
+ import time
25
  from dataclasses import dataclass, field
26
  from pathlib import Path
27
  from typing import Callable, Optional
 
38
  import transformers
39
  from flax import jax_utils, traverse_util
40
  from flax.serialization import from_bytes, to_bytes
 
41
  from flax.jax_utils import unreplicate
42
  from flax.training import train_state
43
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
44
  from transformers import (
45
  AutoTokenizer,
 
46
  HfArgumentParser,
 
47
  )
48
+ from transformers.models.bart.modeling_flax_bart import BartConfig
49
 
50
  import wandb
51
 
52
  from dalle_mini.text import TextNormalizer
53
+ from dalle_mini.model import CustomFlaxBartForConditionalGeneration
54
 
55
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
56
 
57
 
58
  @dataclass
 
62
  """
63
 
64
  model_name_or_path: Optional[str] = field(
65
+ default=None,
66
  metadata={
67
  "help": "The model checkpoint for weights initialization."
68
  "Don't set if you want to train a model from scratch."
69
  },
70
  )
71
+ image_vocab_size: Optional[int] = field(
72
  default=None,
73
+ metadata={"help": "Vocab size of image encoder"},
 
 
74
  )
75
+ image_length: Optional[int] = field(
76
+ default=None,
77
+ metadata={"help": "Number of tokens per image"},
78
+ )
79
+ tokenizer_name: Optional[str] = field(
80
+ default=None,
81
  metadata={
82
+ "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
83
  },
84
  )
85
+ normalize_text: bool = field(
86
+ default=False,
87
+ metadata={"help": "Whether to normalize text or not."},
88
+ )
89
  dtype: Optional[str] = field(
90
  default="float32",
91
  metadata={
92
  "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
93
  },
94
  )
 
 
 
 
 
 
95
 
96
 
97
  @dataclass
 
129
  default=False,
130
  metadata={"help": "Whether to stream the dataset."},
131
  )
132
+ use_auth_token: bool = field(
133
+ default=False,
134
+ metadata={
135
+ "help": "Whether to use the authentication token for private datasets."
136
+ },
 
 
137
  )
138
  max_source_length: Optional[int] = field(
139
  default=128,
 
142
  "than this will be truncated, sequences shorter will be padded."
143
  },
144
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  max_train_samples: Optional[int] = field(
146
  default=None,
147
  metadata={
 
156
  "value if set."
157
  },
158
  )
 
 
 
 
159
  preprocessing_num_workers: Optional[int] = field(
 
 
 
 
160
  default=None,
161
  metadata={
162
+ "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
163
  },
164
  )
165
  overwrite_cache: bool = field(
166
  default=False,
167
+ metadata={
168
+ "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
169
+ },
170
+ )
171
+
172
+ def __post_init__(self):
173
+ if self.dataset_repo_or_path is None:
174
+ raise ValueError("Need a dataset repository or path.")
175
+
176
+
177
+ @dataclass
178
+ class TrainingArguments:
179
+ """
180
+ Arguments pertaining to training parameters.
181
+ """
182
+
183
+ output_dir: str = field(
184
+ metadata={
185
+ "help": "The output directory where the model predictions and checkpoints will be written."
186
+ },
187
+ )
188
+ overwrite_output_dir: bool = field(
189
+ default=False,
190
+ metadata={
191
+ "help": (
192
+ "Overwrite the content of the output directory. "
193
+ "Use this to continue training if output_dir points to a checkpoint directory."
194
+ )
195
+ },
196
+ )
197
+
198
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
199
+ do_eval: bool = field(
200
+ default=False, metadata={"help": "Whether to run eval on the dev set."}
201
+ )
202
+
203
+ per_device_train_batch_size: int = field(
204
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
205
+ )
206
+ per_device_eval_batch_size: int = field(
207
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
208
+ )
209
+
210
+ gradient_accumulation_steps: int = field(
211
+ default=1,
212
+ metadata={
213
+ "help": "Number of updates steps to accumulate before performing a backward/update pass."
214
+ },
215
+ )
216
+
217
+ learning_rate: float = field(
218
+ default=5e-5, metadata={"help": "The initial learning rate."}
219
+ )
220
+ adafactor: bool = field(
221
+ default=False,
222
+ metadata={"help": "Whether or not to replace AdamW by Adafactor."},
223
+ )
224
+ weight_decay: float = field(
225
+ default=None, metadata={"help": "Weight decay if we apply some."}
226
+ )
227
+ adam_beta1: float = field(
228
+ default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
229
+ )
230
+ adam_beta2: float = field(
231
+ default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
232
+ )
233
+ adam_epsilon: float = field(
234
+ default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
235
+ )
236
+ max_grad_norm: float = field(
237
+ default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
238
+ )
239
+ use_decay: bool = field(
240
+ default=False,
241
+ metadata={"help": "Whether to use decay in the learning rate scheduler."},
242
+ )
243
+
244
+ num_train_epochs: float = field(
245
+ default=3.0, metadata={"help": "Total number of training epochs to perform."}
246
+ )
247
+ warmup_steps: int = field(
248
+ default=0, metadata={"help": "Linear warmup over warmup_steps."}
249
+ )
250
+
251
+ logging_steps: int = field(
252
+ default=40, metadata={"help": "Log every X updates steps."}
253
+ )
254
+ eval_steps: int = field(
255
+ default=400, metadata={"help": "Run an evaluation every X steps."}
256
  )
257
+ save_steps: int = field(
258
+ default=4000, metadata={"help": "Save checkpoint every X updates steps."}
 
259
  )
260
  log_model: bool = field(
261
  default=False,
262
+ metadata={"help": "Log model to wandb at `save_steps` frequency."},
263
  )
264
+
265
+ seed_model: int = field(
266
+ default=42,
267
+ metadata={
268
+ "help": "Random seed for the model that will be set at the beginning of training."
269
+ },
270
+ )
271
+ # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
272
+ seed_dataset: int = field(
273
+ default=None,
274
  metadata={
275
+ "help": "Random seed for the dataset that will be set at the beginning of training."
276
  },
277
  )
278
 
279
+ push_to_hub: bool = field(
280
+ default=False,
281
+ metadata={
282
+ "help": "Whether or not to upload the trained model to the model hub after training."
283
+ },
284
+ )
285
+
286
+ resume_from_wandb_checkpoint: Optional[str] = field(
287
+ default=None,
288
+ metadata={"help": "The reference to a wandb artifact for resuming training."},
289
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
 
292
  class TrainState(train_state.TrainState):
293
  dropout_rng: jnp.ndarray = None
294
+ epoch: int = 0
295
+ train_time: float = 0.0 # total time the model trained
296
+ train_samples: int = 0 # number of samples seen
297
 
298
  def replicate(self):
299
  return jax_utils.replicate(self).replace(
 
305
  with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
306
  new_opt_state = from_bytes(self.opt_state, f.read())
307
 
308
+ # restore other parameters
309
  with (Path(artifact_dir) / "training_state.json").open("r") as f:
310
  training_state = json.load(f)
 
311
 
312
  # replace state
313
+ return self.replace(
314
+ opt_state=new_opt_state,
315
+ step=training_state["step"],
316
+ train_time=training_state["train_time"],
317
+ train_samples=training_state["train_samples"],
 
 
 
 
 
 
318
  )
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  def data_loader(
322
+ dataset: Dataset,
323
+ batch_size: int,
324
+ rng: jax.random.PRNGKey = None,
325
  ):
326
  """
327
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
 
329
  """
330
  steps_per_epoch = len(dataset) // batch_size
331
 
332
+ if rng is not None:
333
  batch_idx = jax.random.permutation(rng, len(dataset))
334
  else:
335
  batch_idx = jnp.arange(len(dataset))
 
358
 
359
 
360
  def create_learning_rate_fn(
 
 
 
361
  num_warmup_steps: int,
362
  learning_rate: float,
363
+ use_decay: bool,
364
+ num_train_steps: int = None, # used only with `use_decay`, typically train_size // batch_size * num_epochs
365
  ) -> Callable[[int], jnp.array]:
366
  """Returns a linear warmup, linear_decay learning rate function."""
367
+ if use_decay:
368
+ assert (
369
+ num_train_steps is not None
370
+ ), "Learning rate with decay requires number of training steps"
371
  warmup_fn = optax.linear_schedule(
372
  init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
373
  )
374
+ if not use_decay:
375
  return warmup_fn
376
  decay_fn = optax.linear_schedule(
377
  init_value=learning_rate,
 
395
 
396
 
397
  def main():
398
+ # See all possible arguments by passing the --help flag to this script.
 
 
 
399
  parser = HfArgumentParser(
400
  (ModelArguments, DataTrainingArguments, TrainingArguments)
401
  )
 
420
  )
421
 
422
  # Make one log on every process with the configuration for debugging.
423
+ logging.basicConfig(
424
  format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
425
  datefmt="%m/%d/%Y %H:%M:%S",
426
+ level=logging.INFO,
427
  )
428
  # Setup logging, we only want one process per machine to log things on the screen.
429
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
430
  if jax.process_index() == 0:
431
  datasets.utils.logging.set_verbosity_warning()
432
  transformers.utils.logging.set_verbosity_info()
 
437
  # Set the verbosity to info of the Transformers logger (on main process only):
438
  logger.info(f"Training/evaluation parameters {training_args}")
439
 
440
+ # Load dataset
441
+ if data_args.train_file is not None or data_args.validation_file is not None:
442
+ data_files = {
443
+ "train": data_args.train_file,
444
+ "validation": data_args.validation_file,
445
+ }
446
+ else:
447
+ data_files = None
448
  dataset = load_dataset(
449
  data_args.dataset_repo_or_path,
450
  data_files=data_files,
451
  streaming=data_args.streaming,
452
+ use_auth_token=data_args.use_auth_token,
453
  )
454
 
455
  # Set up wandb run
 
458
  project="dalle-mini",
459
  job_type="Seq2Seq",
460
  config=parser.parse_args(),
 
461
  )
462
 
463
+ if training_args.resume_from_wandb_checkpoint is not None:
464
+ artifact = wandb.run.use_artifact(training_args.resume_from_wandb_checkpoint)
465
  artifact_dir = artifact.download()
466
+
467
+ # load model
468
  model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
469
 
470
  # load tokenizer
471
  tokenizer = AutoTokenizer.from_pretrained(
472
  artifact_dir,
473
+ use_fast=True,
474
  )
475
 
476
  else:
477
  # Set up our new model config
478
+ # TODO: simplify with custom config class
479
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
480
+ config.image_vocab_size = model_args.image_vocab_size
481
+ config.image_length = model_args.image_length
482
+ # we append decoder bos to image vocab
483
+ config.decoder_start_token_id = config.image_vocab_size
484
+ # ensure we don't generate bos (in addition to decoder start token)
485
+ config.force_bos_token_to_be_generated = False
 
 
 
486
  config.forced_bos_token_id = None # we don't need this token
487
  config.forced_eos_token_id = None # we don't need this token
488
+
489
+ config.tie_word_embeddings = False
490
+ config.min_length = model_args.image_length + 1
491
+ config.max_length = model_args.image_length + 1
492
+
493
+ # below tokens need to be set to avoid error during generation (converted to jnp.array)
494
+ # they are not expected to be used and are set to unreachable token id
495
+ config.bos_token_id = config.image_vocab_size + 1
496
+ config.pos_token_id = config.image_vocab_size + 1
497
+ config.eos_token_id = config.image_vocab_size + 1
498
+
499
+ # save whether we normalize the text
500
+ config.normalize_text = model_args.normalize_text
501
 
502
  # Create a custom model and initialize it randomly
503
  model = CustomFlaxBartForConditionalGeneration(
504
+ config, seed=training_args.seed_model, dtype=getattr(jnp, model_args.dtype)
505
  )
506
 
507
  # Load tokenizer
508
+ if model_args.tokenizer_name is not None:
509
+ tokenizer = AutoTokenizer.from_pretrained(
510
+ model_args.tokenizer_name, use_fast=True
511
+ )
512
+ else:
513
+ tokenizer = AutoTokenizer.from_pretrained(
514
+ model_args.model_name_or_path,
515
+ use_fast=True,
516
+ )
517
 
518
  print(f"TPUs: {jax.device_count()}")
519
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
520
 
 
 
521
  # Preprocessing the datasets.
522
  # We need to tokenize inputs and targets.
523
 
 
534
  shifted_input_ids[:, 0] = decoder_start_token_id
535
  return shifted_input_ids
536
 
537
+ text_normalizer = TextNormalizer() if model.config.normalize_text else None
538
 
539
  def normalize_text(example):
540
  example[text_column] = text_normalizer(example[text_column])
 
542
 
543
  def preprocess_function(examples):
544
  inputs = examples[text_column]
 
545
  # Setting padding="max_length" as we need fixed length inputs for jitted functions
546
  model_inputs = tokenizer(
547
  inputs,
 
579
  else train_dataset.select(range(data_args.max_train_samples))
580
  )
581
  if data_args.streaming:
582
+ train_dataset = train_dataset.shuffle(1000, training_args.seed_dataset)
583
+ else:
584
+ seed_dataset = (
585
+ training_args.seed_dataset
586
+ if training_args.seed_dataset is not None
587
+ else np.random.get_state()[1][0]
588
+ )
589
+ rng_dataset = jax.random.PRNGKey(seed_dataset)
590
+ if model.config.normalize_text:
591
  train_dataset = (
592
  train_dataset.map(normalize_text)
593
  if data_args.streaming
 
624
  if data_args.streaming
625
  else eval_dataset.select(range(data_args.max_train_samples))
626
  )
627
+ if model.config.normalize_text:
628
  eval_dataset = (
629
  eval_dataset.map(normalize_text)
630
  if data_args.streaming
 
652
  )
653
 
654
  # Initialize our training
655
+ rng = jax.random.PRNGKey(training_args.seed_model)
656
  rng, dropout_rng = jax.random.split(rng)
657
 
658
  # Store some constant
 
662
  )
663
  batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
664
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
665
+ len_train_dataset, len_eval_dataset = None, None
666
  if data_args.streaming:
667
+ # we don't know the length, let's just assume max_samples if defined
668
+ if data_args.max_train_samples is not None:
 
 
 
669
  len_train_dataset = data_args.max_train_samples
670
+ if data_args.max_eval_samples is not None:
 
 
 
 
 
671
  len_eval_dataset = data_args.max_eval_samples
672
  else:
673
  len_train_dataset = len(train_dataset)
674
  len_eval_dataset = len(eval_dataset)
675
+ steps_per_epoch = (
676
+ len_train_dataset // train_batch_size if len_train_dataset is not None else None
677
+ )
678
+ num_train_steps = (
679
+ steps_per_epoch * num_epochs if steps_per_epoch is not None else None
680
+ )
681
 
682
  # Create learning rate schedule
683
  learning_rate_fn = create_learning_rate_fn(
 
 
 
684
  training_args.warmup_steps,
685
  training_args.learning_rate,
686
+ training_args.use_decay,
687
+ num_train_steps,
688
  )
689
 
690
  # We use Optax's "masking" functionality to not apply weight decay
 
692
  # mask boolean with the same structure as the parameters.
693
  # The mask is True for parameters that should be decayed.
694
  # Note that this mask is specifically adapted for FlaxBart.
 
 
695
  def decay_mask_fn(params):
696
  flat_params = traverse_util.flatten_dict(params)
697
  layer_norm_params = [
 
714
  # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
715
  optimizer = optax.adafactor(
716
  learning_rate=learning_rate_fn,
717
+ weight_decay_rate=training_args.weight_decay,
718
+ weight_decay_mask=decay_mask_fn,
719
+ clipping_threshold=training_args.max_grad_norm,
720
  )
721
  else:
722
  optimizer = optax.adamw(
 
741
  tx=optimizer,
742
  dropout_rng=dropout_rng,
743
  )
744
+ if training_args.resume_from_wandb_checkpoint is not None:
745
+ # restore optimizer state and other parameters
746
+ # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
747
  state = state.restore_state(artifact_dir)
 
 
748
 
749
  # label smoothed cross entropy
750
  def loss_fn(logits, labels):
 
753
  return loss
754
 
755
  # Define gradient update step fn
756
+ def train_step(state, batch, delta_time):
757
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
758
 
759
  def compute_loss(params, batch):
 
767
  grad_fn = jax.value_and_grad(compute_loss)
768
  loss, grads = grad_fn(state.params, batch)
769
  grads = jax.lax.pmean(grads, "batch")
770
+ state = state.apply_gradients(
771
+ grads=grads,
772
+ dropout_rng=new_dropout_rng,
773
+ train_time=state.train_time + delta_time,
774
+ train_samples=state.train_samples + train_batch_size,
775
+ )
776
 
777
  metrics = {
778
  "loss": loss,
779
  "learning_rate": learning_rate_fn(state.step),
780
  }
781
  metrics = jax.lax.pmean(metrics, axis_name="batch")
782
+
783
+ return state, metrics
784
 
785
  # Define eval fn
786
  def eval_step(params, batch):
 
797
  p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
798
  p_eval_step = jax.pmap(eval_step, "batch")
799
 
 
 
 
 
800
  logger.info("***** Running training *****")
801
  logger.info(f" Num examples = {len_train_dataset}")
802
  logger.info(f" Num Epochs = {num_epochs}")
 
806
  logger.info(
807
  f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
808
  )
809
+ epochs = tqdm(
810
+ range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
811
+ )
 
812
 
813
  # set default x-axis as 'train/step'
814
+ wandb_log({}, step=state.step)
815
  wandb.define_metric("*", step_metric="train/step")
816
 
817
  # add interesting config parameters
 
820
  "len_train": len_train_dataset,
821
  "len_eval": len_eval_dataset,
822
  "batch_size_per_update": batch_size_per_update,
 
 
823
  }
824
  )
825
 
826
+ # replicate state on each device
827
+ state = state.replicate()
828
+
829
  def run_evaluation():
830
  # ======================== Evaluating ==============================
831
  eval_metrics = []
 
833
  if data_args.streaming:
834
  eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
835
  else:
836
+ eval_loader = data_loader(eval_dataset, eval_batch_size)
837
+ eval_steps = (
838
+ len_eval_dataset // eval_batch_size
839
+ if len_eval_dataset is not None
840
+ else None
841
+ )
842
  for batch in tqdm(
843
  eval_loader,
844
  desc="Evaluating...",
 
864
 
865
  return eval_metrics
866
 
867
+ def run_save_model(state, eval_metrics=None):
868
  if jax.process_index() == 0:
869
+ params = jax.device_get(unreplicate(state.params))
 
870
  # save model locally
871
  model.save_pretrained(
872
  training_args.output_dir,
 
877
  tokenizer.save_pretrained(training_args.output_dir)
878
 
879
  # save state
880
+ opt_state = unreplicate(state.opt_state)
 
881
  with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
882
+ f.write(to_bytes(opt_state))
883
+ state_dict = {
884
+ k: jax.device_get(unreplicate(getattr(state, k))).item()
885
+ for k in ["step", "epoch", "train_time", "train_samples"]
886
+ }
887
  with (Path(training_args.output_dir) / "training_state.json").open(
888
  "w"
889
  ) as f:
890
+ json.dump(
891
+ state_dict,
892
+ f,
893
+ )
894
 
895
  # save to W&B
896
+ if training_args.log_model:
897
  # save some space
898
  c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
899
+ c.cleanup(wandb.util.from_human_size("10GB"))
900
 
901
+ metadata = dict(state_dict)
902
  if eval_metrics is not None:
903
+ metadata["eval"] = eval_metrics
904
  artifact = wandb.Artifact(
905
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
906
  )
 
934
  training_args.output_dir,
935
  params=params,
936
  push_to_hub=training_args.push_to_hub,
937
+ commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
938
  temp_dir=True, # avoid issues with being in a repository
939
  )
940
 
941
+ # init variables
942
+ last_time = time.perf_counter()
943
+ train_metric = None
944
+
945
  for epoch in epochs:
946
+ state.replace(epoch=jax_utils.replicate(epoch))
947
  # ======================== Training ================================
948
+ wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
 
949
 
950
  # Generate an epoch by shuffling sampling indices from the train dataset
951
  if data_args.streaming:
952
+ train_dataset.set_epoch(epoch) # shuffle dataset
953
  train_loader = data_loader_streaming(train_dataset, train_batch_size)
954
  else:
955
+ rng_dataset, input_rng = jax.random.split(rng_dataset)
956
+ train_loader = data_loader(train_dataset, train_batch_size, rng=input_rng)
 
 
957
  # train
958
  for batch in tqdm(
959
  train_loader,
 
962
  leave=False,
963
  total=steps_per_epoch,
964
  ):
965
+
966
+ # calculate delta time (we have a lag of one step but it's ok)
967
+ new_time = time.perf_counter()
968
+ delta_time = new_time - last_time
969
+ last_time = new_time
970
+
971
+ # train step
972
+ state, train_metric = p_train_step(
973
+ state, batch, jax_utils.replicate(delta_time)
974
+ )
975
  step = unreplicate(state.step)
976
 
977
+ if step % training_args.logging_steps == 0 and jax.process_index() == 0:
978
  # log metrics
979
  wandb_log(unreplicate(train_metric), step=step, prefix="train")
980
+ # log state parameters
981
+ state_dict = {
982
+ k.split("_")[-1]: unreplicate(getattr(state, k))
983
+ for k in ["epoch", "train_time", "train_samples"]
984
+ }
985
+ wandb_log(state_dict, step=step, prefix="train")
986
+
987
+ eval_metrics = None
988
  if training_args.eval_steps and step % training_args.eval_steps == 0:
989
+ eval_metrics = run_evaluation()
990
 
991
+ if step % training_args.save_steps == 0:
992
+ run_save_model(state, eval_metrics)
993
 
994
  # log final train metrics
995
+ if train_metric is not None:
996
+ train_metric = unreplicate(train_metric)
997
+ wandb_log(train_metric, step=step, prefix="train")
998
 
999
+ epochs.write(
1000
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
1001
+ )
 
1002
 
1003
  # Final evaluation
1004
  eval_metrics = run_evaluation()
1005
 
1006
  # save checkpoint after each epoch
1007
+ run_save_model(state, eval_metrics)
1008
 
1009
 
1010
  if __name__ == "__main__":
setup.cfg CHANGED
@@ -12,5 +12,7 @@ project_urls =
12
  packages = find:
13
  install_requires =
14
  transformers
 
 
15
  jax
16
  flax
 
12
  packages = find:
13
  install_requires =
14
  transformers
15
+ unidecode
16
+ ftfy
17
  jax
18
  flax