ydshieh committed on
Commit
ddad56c
1 Parent(s): ea4daa2
Files changed (1)
  1. run_image_captioning_flax_reduced.py +14 -11
run_image_captioning_flax_reduced.py CHANGED
@@ -507,12 +507,6 @@ def main():
     decoder_config.is_decoder = True
     decoder_config.add_cross_attention = True
 
-    # GPT2 only has bos/eos token but not decoder_start/pad token
-    if decoder_config.decoder_start_token_id is None:
-        decoder_config.decoder_start_token_id = decoder_config.bos_token_id
-    if decoder_config.pad_token_id is None:
-        decoder_config.pad_token_id = decoder_config.eos_token_id
-
     model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
         encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
         decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
@@ -523,10 +517,19 @@ def main():
         encoder_dtype=getattr(jnp, model_args.dtype),
         decoder_dtype=getattr(jnp, model_args.dtype),
     )
-    # necessary to make Flax's generate() work
+
+    # GPT2 only has bos/eos tokens but not decoder_start/pad tokens
+    decoder_start_token_id = decoder_config.decoder_start_token_id
+    pad_token_id = decoder_config.pad_token_id
+    if decoder_start_token_id is None:
+        decoder_start_token_id = decoder_config.bos_token_id
+    if pad_token_id is None:
+        pad_token_id = decoder_config.eos_token_id
+
+    # This is necessary to make Flax's generate() work
     model.config.eos_token_id = decoder_config.eos_token_id
-    model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
-    model.config.pad_token_id = decoder_config.pad_token_id
+    model.config.decoder_start_token_id = decoder_start_token_id
+    model.config.pad_token_id = pad_token_id
 
     if model_args.feature_extractor_name:
         feature_extractor = AutoFeatureExtractor.from_pretrained(
@@ -556,7 +559,7 @@ def main():
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
-    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.decoder.pad_token_id)
+    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
@@ -631,7 +634,7 @@ def main():
 
         model_inputs["labels"] = labels["input_ids"]
         decoder_input_ids = shift_tokens_right_fn(
-            labels["input_ids"], model.config.decoder.pad_token_id, model.config.decoder_start_token_id
+            labels["input_ids"], model.config.pad_token_id, model.config.decoder_start_token_id
        )
         model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
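For readers following along, here is a minimal sketch of the token-id wiring the new code converges on, assuming a ViT encoder and a GPT-2 decoder (the checkpoint names and the standalone layout below are illustrative, not taken from this repo). GPT-2 only defines bos/eos tokens, so the decoder-start and pad ids are derived from them and copied onto the top-level config, which is what Flax's generate() consults:

from transformers import AutoTokenizer, FlaxVisionEncoderDecoderModel

# Illustrative checkpoints; the script reads these from model_args instead.
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)
decoder_config = model.config.decoder

# GPT2 only has bos/eos tokens, so derive decoder_start/pad ids from them
decoder_start_token_id = decoder_config.decoder_start_token_id
pad_token_id = decoder_config.pad_token_id
if decoder_start_token_id is None:
    decoder_start_token_id = decoder_config.bos_token_id
if pad_token_id is None:
    pad_token_id = decoder_config.eos_token_id

# Flax's generate() reads these ids from the top-level config
model.config.eos_token_id = decoder_config.eos_token_id
model.config.decoder_start_token_id = decoder_start_token_id
model.config.pad_token_id = pad_token_id

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)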
 
 
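The last two hunks switch the lookups from model.config.decoder.pad_token_id to the top-level model.config.pad_token_id set above. Those two ids feed shift_tokens_right_fn, which builds decoder_input_ids from the labels. As a rough illustration of what a shift_tokens_right-style helper does with them (a generic sketch of the usual transformers pattern, not necessarily the exact function bound to shift_tokens_right_fn in this script):

import numpy as np

def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    # Shift each label sequence one position to the right and prepend the
    # decoder-start id; positions marked -100 (ignored by the loss) become pad.
    shifted = np.zeros_like(input_ids)
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = decoder_start_token_id
    return np.where(shifted == -100, pad_token_id, shifted)

labels = np.array([[40, 2068, 6877, 50256]])  # toy GPT-2 label ids ending in eos
decoder_input_ids = shift_tokens_right(labels, pad_token_id=50256, decoder_start_token_id=50256)
# decoder_input_ids -> [[50256, 40, 2068, 6877]]: for GPT-2, eos doubles as the decoder-start token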