boris committed
Commit ad6ad64
1 Parent(s): a11eff5

feat: model config not hardcoded


Former-commit-id: 8cc773f8dfaee95469a926d907c006873922e1c6

Files changed (1)
  1. seq2seq/run_seq2seq_flax.py +12 -5
seq2seq/run_seq2seq_flax.py CHANGED
@@ -271,6 +271,10 @@ class TrainState(train_state.TrainState):
 
 class CustomFlaxBartModule(FlaxBartModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+        self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
+
         # we keep shared to easily load pre-trained weights
         self.shared = nn.Embed(
             self.config.vocab_size,
@@ -280,7 +284,7 @@ class CustomFlaxBartModule(FlaxBartModule):
         )
         # a separate embedding is used for the decoder
         self.decoder_embed = nn.Embed(
-            OUTPUT_VOCAB_SIZE,
+            self.config.vocab_size_output,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
             dtype=self.dtype,
@@ -289,20 +293,23 @@ class CustomFlaxBartModule(FlaxBartModule):
 
         # the decoder has a different config
         decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
+        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
+        decoder_config.vocab_size = self.config.vocab_size_output
         self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
 
 class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+
         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
-            OUTPUT_VOCAB_SIZE,
+            self.config.vocab_size_output,
             use_bias=False,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
         )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
+        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
 
 class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
     module_class = FlaxBartForConditionalGenerationModule
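
With this change, the decoder's output vocabulary size and position count are read from the model config, falling back to the old module-level constants only when the attributes are absent. A minimal usage sketch, assuming a setup in the style of this script (the model name, attribute values, and comments are illustrative, not part of the commit):

    from transformers import BartConfig

    # hypothetical: attach the new attributes to a base config before
    # instantiating the custom model
    config = BartConfig.from_pretrained("facebook/bart-large-cnn")
    config.vocab_size_output = 16384 + 1              # assumed: image codebook size + BOS
    config.max_position_embeddings_decoder = 256 + 1  # assumed: image token count + BOS

    # attributes missing from the config fall back to OUTPUT_VOCAB_SIZE /
    # OUTPUT_LENGTH via the getattr defaults added in setup()
    model = CustomFlaxBartForConditionalGeneration(config)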