ydshieh
committed on
Commit
•
f2e4555
1
Parent(s):
38eed3e
update 7
Browse files
run_image_captioning_flax_reduced.py
CHANGED
@@ -510,20 +510,18 @@ def main():
|
|
510 |
if decoder_config.pad_token_id is None:
|
511 |
decoder_config.pad_token_id = decoder_config.eos_token_id
|
512 |
|
513 |
-
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
|
514 |
-
# Necessary for Flax's generate()
|
515 |
-
config.decoder_start_token_id = config.decoder.decoder_start_token_id
|
516 |
-
|
517 |
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
518 |
encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
|
519 |
decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
|
520 |
-
encoder_config=
|
521 |
-
decoder_config=
|
522 |
encoder_seed=training_args.seed,
|
523 |
decoder_seed=training_args.seed,
|
524 |
encoder_dtype=getattr(jnp, model_args.dtype),
|
525 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
526 |
)
|
|
|
|
|
527 |
|
528 |
if model_args.feature_extractor_name:
|
529 |
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
@@ -553,7 +551,7 @@ def main():
|
|
553 |
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
554 |
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
555 |
)
|
556 |
-
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.decoder.pad_token_id)
|
557 |
|
558 |
# Preprocessing the datasets.
|
559 |
# We need to tokenize inputs and targets.
|
@@ -628,7 +626,7 @@ def main():
|
|
628 |
|
629 |
model_inputs["labels"] = labels["input_ids"]
|
630 |
decoder_input_ids = shift_tokens_right_fn(
|
631 |
-
labels["input_ids"], config.decoder.pad_token_id, config.decoder_start_token_id
|
632 |
)
|
633 |
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
634 |
|
@@ -687,9 +685,9 @@ def main():
|
|
687 |
{
|
688 |
"pixel_values": datasets.Array3D(
|
689 |
shape=(
|
690 |
-
getattr(config.encoder, "num_channels", 3),
|
691 |
-
config.encoder.image_size,
|
692 |
-
config.encoder.image_size,
|
693 |
),
|
694 |
dtype="float32",
|
695 |
),
|
|
|
510 |
if decoder_config.pad_token_id is None:
|
511 |
decoder_config.pad_token_id = decoder_config.eos_token_id
|
512 |
|
|
|
|
|
|
|
|
|
513 |
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
514 |
encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
|
515 |
decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
|
516 |
+
encoder_config=encoder_config,
|
517 |
+
decoder_config=decoder_config,
|
518 |
encoder_seed=training_args.seed,
|
519 |
decoder_seed=training_args.seed,
|
520 |
encoder_dtype=getattr(jnp, model_args.dtype),
|
521 |
decoder_dtype=getattr(jnp, model_args.dtype),
|
522 |
)
|
523 |
+
# Necessary for Flax's generate()
|
524 |
+
model.config.decoder_start_token_id = decoder_config.decoder_start_token_id
|
525 |
|
526 |
if model_args.feature_extractor_name:
|
527 |
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
|
551 |
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
552 |
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
553 |
)
|
554 |
+
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.decoder.pad_token_id)
|
555 |
|
556 |
# Preprocessing the datasets.
|
557 |
# We need to tokenize inputs and targets.
|
|
|
626 |
|
627 |
model_inputs["labels"] = labels["input_ids"]
|
628 |
decoder_input_ids = shift_tokens_right_fn(
|
629 |
+
labels["input_ids"], model.config.decoder.pad_token_id, model.config.decoder_start_token_id
|
630 |
)
|
631 |
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
632 |
|
|
|
685 |
{
|
686 |
"pixel_values": datasets.Array3D(
|
687 |
shape=(
|
688 |
+
getattr(model.config.encoder, "num_channels", 3),
|
689 |
+
model.config.encoder.image_size,
|
690 |
+
model.config.encoder.image_size,
|
691 |
),
|
692 |
dtype="float32",
|
693 |
),
|