Pedro Cuenca committed on
Commit
a841a4c
1 Parent(s): a104edb

Decoder: set eos to an unreachable value, set min_length=max_length to keep generation from stopping before max_length

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +5 -1
seq2seq/run_seq2seq_flax.py CHANGED
@@ -258,6 +258,8 @@ class CustomFlaxBartModule(FlaxBartModule):
258
  # the decoder has a different config
259
  decoder_config = BartConfig(self.config.to_dict())
260
  decoder_config.max_position_embeddings = OUTPUT_LENGTH
 
 
261
  decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
262
  self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
263
 
@@ -407,7 +409,9 @@ def main():
407
  config.decoder_start_token_id = BOS_TOKEN_ID
408
  config.bos_token_id = BOS_TOKEN_ID # should not be used
409
  config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
410
- config.eos_token_id = None # prevents generation from stopping until we reach max_length
 
 
411
 
412
 
413
  # Create a custom model and initialize it randomly
 
258
  # the decoder has a different config
259
  decoder_config = BartConfig(self.config.to_dict())
260
  decoder_config.max_position_embeddings = OUTPUT_LENGTH
261
+ decoder_config.min_length = OUTPUT_LENGTH
262
+ decoder_config.max_length = OUTPUT_LENGTH
263
  decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
264
  self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
265
 
 
409
  config.decoder_start_token_id = BOS_TOKEN_ID
410
  config.bos_token_id = BOS_TOKEN_ID # should not be used
411
  config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
412
+ config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
413
+ #config.min_length = data_args.max_target_length # Set only in decoder?
414
+ #config.max_length = data_args.max_target_length # Set only in decoder?
415
 
416
 
417
  # Create a custom model and initialize it randomly