import copy
from typing import Any, Dict, Optional

from transformers import ViTConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class ViTGPT2Config(PretrainedConfig):
    """Composite configuration wrapping a `ViTConfig` and a `GPT2Config`.

    The two sub-configurations are built from plain dictionaries and exposed as
    `vision_config` and `text_config`. Several token ids are copied from the
    text configuration onto this object, and a `project_encoder` flag is kept
    in sync between this object and `text_config`.

    Args:
        vision_config_dict: Keyword arguments forwarded to `ViTConfig`. When
            `None`, `ViTConfig` is created with its default values.
        text_config_dict: Keyword arguments forwarded to `GPT2Config`. When
            `None`, `GPT2Config` is created with its default values.
        **kwargs: Forwarded to `PretrainedConfig.__init__`. The optional key
            `project_encoder` is consumed here: it must agree with
            `text_config.project_encoder` when both are given, and resolves to
            `False` when neither is given.

    Raises:
        ValueError: If the text configuration has no `pad_token_id` attribute,
            or if `project_encoder` conflicts with the value carried by the
            text configuration.
    """

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vision_config_dict: Optional[Dict[str, Any]] = None,
        text_config_dict: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Consume `project_encoder` before forwarding kwargs: it is resolved
        # against the text config below and assigned explicitly at the end.
        project_encoder = kwargs.pop("project_encoder", None)

        # The raw dictionaries (possibly None) are handed to the base class
        # unchanged, before the None -> {} defaulting below.
        super().__init__(
            vision_config_dict=vision_config_dict,
            text_config_dict=text_config_dict,
            **kwargs,
        )

        if vision_config_dict is None:
            vision_config_dict = {}
            logger.info("vision_config_dict is None. initializing the ViTConfig with default values.")

        if text_config_dict is None:
            text_config_dict = {}
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")

        self.vision_config = ViTConfig(**vision_config_dict)
        self.text_config = GPT2Config(**text_config_dict)

        self.is_encoder_decoder = True

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not hasattr(self.text_config, "pad_token_id"):
            raise ValueError("text_config must define a `pad_token_id` attribute.")

        # Required in `generate()`.
        self.bos_token_id = self.text_config.bos_token_id
        self.eos_token_id = self.text_config.eos_token_id
        self.pad_token_id = self.text_config.pad_token_id
        self.decoder_start_token_id = self.text_config.bos_token_id
        self.forced_eos_token_id = self.text_config.eos_token_id

        # Resolve `project_encoder` from the two possible sources: the explicit
        # keyword argument and the value carried by the text config.
        text_project_encoder = getattr(self.text_config, "project_encoder", None)
        if project_encoder is None:
            # Fall back to the text config's value, or to False if it has none.
            project_encoder = False if text_project_encoder is None else text_project_encoder
        elif text_project_encoder is not None and project_encoder != text_project_encoder:
            raise ValueError(
                "Conflicting `project_encoder` values: "
                f"got {project_encoder!r} as a keyword argument but "
                f"{text_project_encoder!r} in the text config."
            )

        self.project_encoder = project_encoder
        self.text_config.project_encoder = project_encoder

    @classmethod
    def from_vision_text_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        """Build a `ViTGPT2Config` from already-instantiated sub-configurations.

        Args:
            vision_config: Vision configuration; serialized with `to_dict()`.
            text_config: Text configuration; serialized with `to_dict()`.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A new instance of `cls`.
        """
        return cls(
            vision_config_dict=vision_config.to_dict(),
            text_config_dict=text_config.to_dict(),
            **kwargs,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this instance to a Python dictionary.

        Returns:
            A deep copy of the instance attributes in which `vision_config` and
            `text_config` are replaced by their own `to_dict()` output and
            `model_type` is set to the class-level `model_type`.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output