boris commited on
Commit
a173dad
1 Parent(s): 4aced93

Former-commit-id: 812d34f157e905ab47e1081b79cc0a80b37fd19b

Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +1 -1
seq2seq/run_seq2seq_flax.py CHANGED
@@ -273,7 +273,7 @@ class CustomFlaxBartModule(FlaxBartModule):
273
  def setup(self):
274
  # check config is valid, otherwise set default values
275
  self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
276
- self.config.max_position_embeddings_decoder = getattr(self.config, 'vocab_size_output', OUTPUT_LENGTH)
277
 
278
  # we keep shared to easily load pre-trained weights
279
  self.shared = nn.Embed(
 
273
  def setup(self):
274
  # check config is valid, otherwise set default values
275
  self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
276
+ self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
277
 
278
  # we keep shared to easily load pre-trained weights
279
  self.shared = nn.Embed(