Pedro Cuenca committed

Commit 7e48337
Parent: 2b2be9b

Tokenizer, config, model can be loaded from wandb.

src/dalle_mini/model/__init__.py CHANGED
@@ -1,2 +1,3 @@
  from .configuration import DalleBartConfig
  from .modeling import DalleBart
+ from .tokenizer import DalleBartTokenizer
src/dalle_mini/model/configuration.py CHANGED
@@ -18,10 +18,12 @@ import warnings
  from transformers.configuration_utils import PretrainedConfig
  from transformers.utils import logging

+ from .wandb_pretrained import PretrainedFromWandbMixin
+
  logger = logging.get_logger(__name__)


- class DalleBartConfig(PretrainedConfig):
+ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
      model_type = "dallebart"
      keys_to_ignore_at_inference = ["past_key_values"]
      attribute_map = {
src/dalle_mini/model/modeling.py CHANGED
@@ -15,14 +15,12 @@
  """ DalleBart model. """

  import math
- import os
  from functools import partial
  from typing import Optional, Tuple

  import flax.linen as nn
  import jax
  import jax.numpy as jnp
- import wandb
  from flax.core.frozen_dict import unfreeze
  from flax.linen import make_causal_mask
  from flax.traverse_util import flatten_dict
@@ -48,6 +46,7 @@ from transformers.models.bart.modeling_flax_bart import (
  from transformers.utils import logging

  from .configuration import DalleBartConfig
+ from .wandb_pretrained import PretrainedFromWandbMixin

  logger = logging.get_logger(__name__)

@@ -421,7 +420,9 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
          )


- class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
+ class DalleBart(
+     PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
+ ):
      """
      Edits:
      - renamed from FlaxBartForConditionalGeneration
@@ -563,24 +564,3 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
              outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

          return outputs
-
-     @classmethod
-     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-         """
-         Initializes from a wandb artifact, or delegates loading to the superclass.
-         """
-         if ":" in pretrained_model_name_or_path and not os.path.isdir(
-             pretrained_model_name_or_path
-         ):
-             # wandb artifact
-             artifact = wandb.Api().artifact(pretrained_model_name_or_path)
-
-             # we download everything, including opt_state, so we can resume training if needed
-             # see also: #120
-             pretrained_model_name_or_path = artifact.download()
-
-         model = super(DalleBart, cls).from_pretrained(
-             pretrained_model_name_or_path, *model_args, **kwargs
-         )
-         model.config.resolved_name_or_path = pretrained_model_name_or_path
-         return model
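Note that the mixin is listed first in each class's base list, so under Python's method resolution order its from_pretrained intercepts the call and then hands off to the transformers implementation via super(). A minimal, self-contained sketch of that pattern, using stand-in class names rather than the real dalle_mini classes:

# MRO sketch with stand-in classes (not the real dalle_mini classes):
# the mixin's from_pretrained runs first, then delegates with super().
class LoaderBase:
    @classmethod
    def from_pretrained(cls, name, **kwargs):
        return f"base load of {name}"


class ArtifactMixin:
    @classmethod
    def from_pretrained(cls, name, **kwargs):
        if ":" in name:
            # pretend the artifact reference was resolved to a local directory
            name = "downloaded/" + name.split(":")[0]
        return super().from_pretrained(name, **kwargs)


class Model(ArtifactMixin, LoaderBase):
    pass


print(Model.from_pretrained("entity/project/model:latest"))
# -> base load of downloaded/entity/project/model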
src/dalle_mini/model/tokenizer.py ADDED
@@ -0,0 +1,11 @@
+ """ DalleBart tokenizer """
+ from transformers import BartTokenizer
+ from transformers.utils import logging
+
+ from .wandb_pretrained import PretrainedFromWandbMixin
+
+ logger = logging.get_logger(__name__)
+
+
+ class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizer):
+     pass
src/dalle_mini/model/wandb_pretrained.py ADDED
@@ -0,0 +1,20 @@
+ import os
+ import wandb
+
+
+ class PretrainedFromWandbMixin:
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         """
+         Initializes from a wandb artifact, or delegates loading to the superclass.
+         """
+         if ":" in pretrained_model_name_or_path and not os.path.isdir(
+             pretrained_model_name_or_path
+         ):
+             # wandb artifact
+             artifact = wandb.Api().artifact(pretrained_model_name_or_path)
+             pretrained_model_name_or_path = artifact.download()
+
+         return super(PretrainedFromWandbMixin, cls).from_pretrained(
+             pretrained_model_name_or_path, *model_args, **kwargs
+         )
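With the mixin shared across the three classes, the tokenizer, config, and model can all be loaded from the same wandb artifact reference. A usage sketch under that assumption follows; the artifact name is illustrative, not a real published artifact.

# Usage sketch: any pretrained name containing ":" that is not a local directory
# is treated as a wandb artifact, downloaded, and then loaded from the local copy;
# anything else falls through to the regular transformers from_pretrained.
from dalle_mini.model import DalleBart, DalleBartConfig, DalleBartTokenizer

artifact_ref = "entity/project/model-run:latest"  # hypothetical artifact reference

tokenizer = DalleBartTokenizer.from_pretrained(artifact_ref)
config = DalleBartConfig.from_pretrained(artifact_ref)
model = DalleBart.from_pretrained(artifact_ref)

Keeping the resolution logic in a single mixin means the artifact-handling behavior cannot drift between model, config, and tokenizer.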