import copy

from transformers import GPT2Config, ViTConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class ViTGPT2Config(PretrainedConfig):
    """Composite configuration for a ViT vision encoder paired with a GPT-2 text decoder.

    Holds a nested :class:`ViTConfig` (``vision_config``) and
    :class:`GPT2Config` (``text_config``) and marks the model as an
    encoder-decoder so generation utilities pick up the decoder's special
    tokens.

    Args:
        vision_config_dict (`dict`, *optional*):
            Keyword arguments used to build the ``ViTConfig``. Defaults to an
            empty dict (i.e. ViT defaults) when ``None``.
        text_config_dict (`dict`, *optional*):
            Keyword arguments used to build the ``GPT2Config``. Defaults to an
            empty dict (i.e. GPT-2 defaults) when ``None``.
        kwargs:
            Remaining keyword arguments forwarded to ``PretrainedConfig``.
    """

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
        # Forward only the *remaining* kwargs to the parent. Passing the raw
        # config dicts through would make PretrainedConfig store them as
        # instance attributes (possibly as `None`), duplicating — and
        # potentially contradicting — the serialized `vision_config` /
        # `text_config` entries produced by `to_dict`.
        super().__init__(**kwargs)

        if vision_config_dict is None:
            vision_config_dict = {}
            logger.info("vision_config_dict is None. initializing the ViTConfig with default values.")

        if text_config_dict is None:
            text_config_dict = {}
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")

        self.vision_config = ViTConfig(**vision_config_dict)
        self.text_config = GPT2Config(**text_config_dict)

        self.is_encoder_decoder = True
        # Generation defaults mirror the decoder's special tokens: start
        # decoding from GPT-2's BOS and force its EOS at max length.
        self.decoder_start_token_id = self.text_config.bos_token_id
        self.forced_eos_token_id = self.text_config.eos_token_id

    @classmethod
    def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        """Instantiate from already-constructed ``ViTConfig`` and ``GPT2Config`` objects.

        Args:
            vision_config (`ViTConfig`): Configuration of the vision encoder.
            text_config (`GPT2Config`): Configuration of the text decoder.
            kwargs: Extra keyword arguments forwarded to ``__init__``.

        Returns:
            `ViTGPT2Config`: A new composite configuration instance.
        """
        return cls(
            vision_config_dict=vision_config.to_dict(),
            text_config_dict=text_config.to_dict(),
            **kwargs,
        )

    def to_dict(self):
        """Serialize this configuration to a plain Python dict.

        Overrides the default so the two sub-configs are serialized as nested
        dicts and ``model_type`` reflects this composite class.

        Returns:
            `dict[str, Any]`: Dictionary of all attributes of this instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output