boris commited on
Commit
92ccf4c
1 Parent(s): 0ca6514

feat(model): set default config for legacy models

Browse files
Files changed (1) hide show
  1. dalle_mini/model.py +5 -0
dalle_mini/model.py CHANGED
@@ -46,6 +46,11 @@ class CustomFlaxBartForConditionalGenerationModule(
46
  FlaxBartForConditionalGenerationModule
47
  ):
48
  def setup(self):
 
 
 
 
 
49
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
50
  self.lm_head = nn.Dense(
51
  self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
 
46
  FlaxBartForConditionalGenerationModule
47
  ):
48
  def setup(self):
49
+ # set default config
50
+ self.config.normalize_text = getattr(self.config, "normalize_text", False)
51
+ self.config.image_length = getattr(self.config, "image_length", 256)
52
+ self.config.image_vocab_size = getattr(self.config, "image_vocab_size", 16384)
53
+
54
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
55
  self.lm_head = nn.Dense(
56
  self.config.image_vocab_size + 1, # encoded image token space + 1 for bos