multimodal / transformers /examples /flax /image-captioning /create_model_from_encoder_decoder_models.py
Li
add transformers
455a40f
raw
history blame
4.49 kB
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
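Create a FlaxVisionEncoderDecoderModel instance from pretrained encoder/decoder models and save it,
together with the matching image processor and tokenizer, to the output directory.
The cross-attention weights will be randomly initialized.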
"""
from dataclasses import dataclass, field
from typing import Optional
from transformers import AutoConfig, AutoImageProcessor, AutoTokenizer, FlaxVisionEncoderDecoderModel, HfArgumentParser
@dataclass
class ModelArguments:
    """
    Arguments selecting the pretrained encoder/decoder checkpoints to combine and where to save the result.

    `output_dir`, `encoder_model_name_or_path` and `decoder_model_name_or_path` declare no default value, so
    they must always be supplied. The two `*_config_name` fields default to `None` and are optional overrides.
    """

    # Destination directory for everything the script saves.
    output_dir: str = field(
        metadata={"help": "The output directory where the model will be written."},
    )
    # Checkpoint the encoder weights are loaded from (no default: mandatory).
    encoder_model_name_or_path: str = field(
        metadata={
            "help": (
                "The encoder model checkpoint for weights initialization. "
                "Required: the encoder is always loaded from a pretrained checkpoint."
            )
        },
    )
    # Checkpoint the decoder weights are loaded from (no default: mandatory).
    decoder_model_name_or_path: str = field(
        metadata={
            "help": (
                "The decoder model checkpoint for weights initialization. "
                "Required: the decoder is always loaded from a pretrained checkpoint."
            )
        },
    )
    # Optional config overrides; when left as None the checkpoint's own config is used.
    encoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
    )
    decoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
    )
def main():
    """Build a Flax vision encoder-decoder model from two pretrained checkpoints and save it.

    Parses `ModelArguments` from the command line, loads the encoder and decoder configs, combines the two
    checkpoints into one model, fills in the special token ids on the combined config, and writes the model,
    the image processor and the tokenizer to `output_dir`.
    """
    (args,) = HfArgumentParser((ModelArguments,)).parse_args_into_dataclasses()

    # An explicitly specified config name takes precedence; otherwise use the pretrained checkpoint's config.
    encoder_config = AutoConfig.from_pretrained(args.encoder_config_name or args.encoder_model_name_or_path)
    decoder_config = AutoConfig.from_pretrained(args.decoder_config_name or args.decoder_model_name_or_path)

    # necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True

    model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_pretrained_model_name_or_path=args.encoder_model_name_or_path,
        decoder_pretrained_model_name_or_path=args.decoder_model_name_or_path,
        encoder_config=encoder_config,
        decoder_config=decoder_config,
    )

    # GPT2 only has bos/eos tokens but not decoder_start/pad tokens, so fall back to bos/eos when unset.
    start_id = decoder_config.decoder_start_token_id
    pad_id = decoder_config.pad_token_id

    # This is necessary to make Flax's generate() work
    combined_config = model.config
    combined_config.eos_token_id = decoder_config.eos_token_id
    combined_config.decoder_start_token_id = decoder_config.bos_token_id if start_id is None else start_id
    combined_config.pad_token_id = decoder_config.eos_token_id if pad_id is None else pad_id

    image_processor = AutoImageProcessor.from_pretrained(args.encoder_model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.decoder_model_name_or_path)
    # Keep the tokenizer's pad token in sync with the pad id chosen for the model.
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)

    # Save in the same order as before: model, image processor, tokenizer.
    for artifact in (model, image_processor, tokenizer):
        artifact.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()