boris commited on
Commit
5aaf9df
1 Parent(s): eb591ff

fix: model config

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +2 -4
seq2seq/run_seq2seq_flax.py CHANGED
@@ -282,8 +282,6 @@ class CustomFlaxBartModule(FlaxBartModule):
282
  # the decoder has a different config
283
  decoder_config = BartConfig(self.config.to_dict())
284
  decoder_config.max_position_embeddings = OUTPUT_LENGTH
285
- decoder_config.min_length = OUTPUT_LENGTH
286
- decoder_config.max_length = OUTPUT_LENGTH
287
  decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
288
  self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
289
 
@@ -440,8 +438,8 @@ def main():
440
  config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
441
  config.forced_bos_token_id = None # we don't need this token
442
  config.forced_eos_token_id = None # we don't need this token
443
- #config.min_length = data_args.max_target_length # Set only in decoder?
444
- #config.max_length = data_args.max_target_length # Set only in decoder?
445
 
446
  print(f"TPUs: {jax.device_count()}")
447
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
282
  # the decoder has a different config
283
  decoder_config = BartConfig(self.config.to_dict())
284
  decoder_config.max_position_embeddings = OUTPUT_LENGTH
 
 
285
  decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
286
  self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
287
 
 
438
  config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
439
  config.forced_bos_token_id = None # we don't need this token
440
  config.forced_eos_token_id = None # we don't need this token
441
+ config.min_length = data_args.max_target_length
442
+ config.max_length = data_args.max_target_length
443
 
444
  print(f"TPUs: {jax.device_count()}")
445
  assert jax.device_count() == 8, "TPUs in use, please check running processes"