boris committed
Commit 7f962d6
1 Parent(s): 1f05876

feat: separate model definition

Former-commit-id: c049a9387bdbadc71f5ee9f17d42aa25d6233ebd

Files changed (2)
  1. app/app_gradio.py +10 -56
  2. dalle_mini/model.py +66 -0
app/app_gradio.py CHANGED
@@ -12,74 +12,28 @@ import flax.linen as nn
 from flax.training.common_utils import shard
 from flax.jax_utils import replicate, unreplicate
 
-from transformers.models.bart.modeling_flax_bart import *
 from transformers import BartTokenizer, FlaxBartForConditionalGeneration
 
-
-import requests
 from PIL import Image
 import numpy as np
 import matplotlib.pyplot as plt
 
 
 from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+from dalle_mini.model import CustomFlaxBartForConditionalGeneration
 
 import gradio as gr
 
 
-# TODO: set those args in a config file
-OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
-OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
-BOS_TOKEN_ID = 16384
-BASE_MODEL = 'flax-community/dalle-mini'
-
-class CustomFlaxBartModule(FlaxBartModule):
-    def setup(self):
-        # we keep shared to easily load pre-trained weights
-        self.shared = nn.Embed(
-            self.config.vocab_size,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        # a separate embedding is used for the decoder
-        self.decoder_embed = nn.Embed(
-            OUTPUT_VOCAB_SIZE,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
-
-        # the decoder has a different config
-        decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
-        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
-
-class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
-    def setup(self):
-        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
-        self.lm_head = nn.Dense(
-            OUTPUT_VOCAB_SIZE,
-            use_bias=False,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-        )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
-
-class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
-    module_class = CustomFlaxBartForConditionalGenerationModule
-
-# create our model
-# FIXME: Save tokenizer to hub so we can load from there
-tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
-model = CustomFlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)
-model.config.force_bos_token_to_be_generated = False
-model.config.forced_bos_token_id = None
-model.config.forced_eos_token_id = None
-
-vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
+DALLE_REPO = 'flax-community/dalle-mini'
+DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
+
+VQGAN_REPO = 'flax-community/vqgan_f16_16384'
+VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'
+
+tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
+model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
+vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
 
 def custom_to_pil(x):
     x = np.clip(x, 0., 1.)
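
For context, the three pinned checkpoints above (tokenizer, seq2seq model, VQGAN) are combined at inference time roughly as follows. This is a hedged sketch, not part of the commit: the prompt, sampling settings, and the `vqgan.decode_code` helper are assumptions, since the rest of app_gradio.py is not shown in this diff.

    # Hedged sketch of the inference flow these checkpoints enable.
    import jax
    import numpy as np

    prompt = "a watercolor painting of a lighthouse"  # hypothetical example prompt
    inputs = tokenizer(prompt, return_tensors='jax', padding='max_length',
                       truncation=True, max_length=128)

    # sample image token ids with the seq2seq model (standard Flax generate API)
    output = model.generate(**inputs, do_sample=True, prng_key=jax.random.PRNGKey(0))

    # drop the BOS token, then decode the remaining image tokens back to pixels
    codes = output.sequences[:, 1:]
    images = vqgan.decode_code(codes)             # assumed dalle_mini VQModel helper
    images = np.clip(np.asarray(images), 0., 1.)  # same clipping as custom_to_pil

Pinning `revision` to a full commit SHA, as the new code does, keeps the demo reproducible even if newer weights are later pushed to the repos' main branches.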
dalle_mini/model.py ADDED
@@ -0,0 +1,66 @@
+
+import jax
+import flax.linen as nn
+
+from transformers.models.bart.modeling_flax_bart import (
+    FlaxBartModule,
+    FlaxBartForConditionalGenerationModule,
+    FlaxBartForConditionalGeneration,
+    FlaxBartEncoder,
+    FlaxBartDecoder
+)
+
+from transformers import BartConfig
+
+
+# Model hyperparameters, for convenience
+OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
+OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
+BOS_TOKEN_ID = 16384
+BASE_MODEL = 'facebook/bart-large-cnn'  # we currently have issues with bart-large
+
+
+class CustomFlaxBartModule(FlaxBartModule):
+    def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+        self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
+
+        # we keep shared to easily load pre-trained weights
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        # a separate embedding is used for the decoder
+        self.decoder_embed = nn.Embed(
+            self.config.vocab_size_output,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+
+        # the decoder has a different config (copy the encoder's, then override)
+        decoder_config = BartConfig(**self.config.to_dict())
+        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
+        decoder_config.vocab_size = self.config.vocab_size_output
+        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
+
+class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
+    def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+
+        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
+        self.lm_head = nn.Dense(
+            self.config.vocab_size_output,
+            use_bias=False,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+        )
+        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))

+class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
+    module_class = CustomFlaxBartForConditionalGenerationModule
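
The `getattr` fallbacks in both `setup` methods let the class load older checkpoints whose config predates `vocab_size_output` and `max_position_embeddings_decoder`: missing fields simply fall back to the module-level constants. A small sketch, not from this commit, of how those fields could be written into the config and persisted so later loads no longer depend on the defaults (the local save path is hypothetical):

    from dalle_mini.model import CustomFlaxBartForConditionalGeneration

    model = CustomFlaxBartForConditionalGeneration.from_pretrained('flax-community/dalle-mini')

    # make the decoder-side hyperparameters explicit in the config
    model.config.vocab_size_output = 16384 + 1               # image codebook + BOS
    model.config.max_position_embeddings_decoder = 256 + 1   # image tokens + BOS

    # hypothetical local path; save_pretrained writes config.json alongside the weights
    model.save_pretrained('./dalle-mini-with-explicit-config')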