ydshieh committed on
Commit
f2e4555
1 Parent(s): 38eed3e
run_image_captioning_flax_reduced.py CHANGED
@@ -510,20 +510,18 @@ def main():
510
  if decoder_config.pad_token_id is None:
511
  decoder_config.pad_token_id = decoder_config.eos_token_id
512
 
513
- config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
514
- # Necessary for Flax's generate()
515
- config.decoder_start_token_id = config.decoder.decoder_start_token_id
516
-
517
  model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
518
  encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
519
  decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
520
- encoder_config=config.encoder,
521
- decoder_config=config.decoder,
522
  encoder_seed=training_args.seed,
523
  decoder_seed=training_args.seed,
524
  encoder_dtype=getattr(jnp, model_args.dtype),
525
  decoder_dtype=getattr(jnp, model_args.dtype),
526
  )
 
 
527
 
528
  if model_args.feature_extractor_name:
529
  feature_extractor = AutoFeatureExtractor.from_pretrained(
@@ -553,7 +551,7 @@ def main():
553
  "You are instantiating a new tokenizer from scratch. This is not supported by this script."
554
  "You can do it from another script, save it, and load it from here, using --tokenizer_name."
555
  )
556
- tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.decoder.pad_token_id)
557
 
558
  # Preprocessing the datasets.
559
  # We need to tokenize inputs and targets.
@@ -628,7 +626,7 @@ def main():
628
 
629
  model_inputs["labels"] = labels["input_ids"]
630
  decoder_input_ids = shift_tokens_right_fn(
631
- labels["input_ids"], config.decoder.pad_token_id, config.decoder.decoder_start_token_id
632
  )
633
  model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
634
 
@@ -687,9 +685,9 @@ def main():
687
  {
688
  "pixel_values": datasets.Array3D(
689
  shape=(
690
- getattr(config.encoder, "num_channels", 3),
691
- config.encoder.image_size,
692
- config.encoder.image_size,
693
  ),
694
  dtype="float32",
695
  ),
 
510
  if decoder_config.pad_token_id is None:
511
  decoder_config.pad_token_id = decoder_config.eos_token_id
512
 
 
 
 
 
513
  model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
514
  encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
515
  decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
516
+ encoder_config=encoder_config,
517
+ decoder_config=decoder_config,
518
  encoder_seed=training_args.seed,
519
  decoder_seed=training_args.seed,
520
  encoder_dtype=getattr(jnp, model_args.dtype),
521
  decoder_dtype=getattr(jnp, model_args.dtype),
522
  )
523
+ # Necessary for Flax's generate()
524
+ model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
525
 
526
  if model_args.feature_extractor_name:
527
  feature_extractor = AutoFeatureExtractor.from_pretrained(
 
551
  "You are instantiating a new tokenizer from scratch. This is not supported by this script."
552
  "You can do it from another script, save it, and load it from here, using --tokenizer_name."
553
  )
554
+ tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.decoder.pad_token_id)
555
 
556
  # Preprocessing the datasets.
557
  # We need to tokenize inputs and targets.
 
626
 
627
  model_inputs["labels"] = labels["input_ids"]
628
  decoder_input_ids = shift_tokens_right_fn(
629
+ labels["input_ids"], model.config.decoder.pad_token_id, model.config.decoder_start_token_id
630
  )
631
  model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
632
 
 
685
  {
686
  "pixel_values": datasets.Array3D(
687
  shape=(
688
+ getattr(model.config.encoder, "num_channels", 3),
689
+ model.config.encoder.image_size,
690
+ model.config.encoder.image_size,
691
  ),
692
  dtype="float32",
693
  ),