ydshieh commited on
Commit
bcae421
1 Parent(s): 15ecbe8
run_image_captioning_flax_reduced.py CHANGED
@@ -524,6 +524,7 @@ def main():
524
  decoder_dtype=getattr(jnp, model_args.dtype),
525
  )
526
  # necessary to make Flax's generate() work
 
527
  model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
528
 
529
  if model_args.feature_extractor_name:
 
524
  decoder_dtype=getattr(jnp, model_args.dtype),
525
  )
526
  # necessary to make Flax's generate() work
527
+ model.config.eos_token_id = decoder_config.eos_token_id
528
  model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
529
 
530
  if model_args.feature_extractor_name: