|
import copy |
|
|
|
from transformers import ViTConfig, GPT2Config |
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.utils import logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class ViTGPT2Config(PretrainedConfig):
    """Configuration for a ViT-encoder / GPT-2-decoder composite model.

    Wraps a :class:`~transformers.ViTConfig` (vision encoder, exposed as
    ``self.vision_config``) and a :class:`~transformers.GPT2Config` (text
    decoder, exposed as ``self.text_config``) in a single encoder-decoder
    configuration.
    """

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
        """Build the composite configuration.

        Args:
            vision_config_dict (`dict`, *optional*):
                Keyword arguments used to initialize the `ViTConfig`.
                Defaults are used when `None`.
            text_config_dict (`dict`, *optional*):
                Keyword arguments used to initialize the `GPT2Config`.
                Defaults are used when `None`.
            kwargs:
                Forwarded to `PretrainedConfig.__init__`.
        """
        # `to_dict` serializes the sub-configs under the keys "vision_config"
        # and "text_config"; accept those as a fallback so configs saved with
        # `to_dict`/`save_pretrained` round-trip through `from_dict` instead
        # of silently re-initializing the sub-configs with default values.
        if vision_config_dict is None:
            vision_config_dict = kwargs.pop("vision_config", None)
        if text_config_dict is None:
            text_config_dict = kwargs.pop("text_config", None)

        # Do NOT forward the raw dicts to the parent: `PretrainedConfig`
        # setattr's unknown kwargs, which would leave stray
        # `vision_config_dict`/`text_config_dict` attributes that then leak
        # into the output of `to_dict`.
        super().__init__(**kwargs)

        if vision_config_dict is None:
            vision_config_dict = {}
            logger.info("vision_config_dict is None. Initializing the ViTConfig with default values.")

        if text_config_dict is None:
            text_config_dict = {}
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")

        self.vision_config = ViTConfig(**vision_config_dict)
        self.text_config = GPT2Config(**text_config_dict)

        self.is_encoder_decoder = True

        # Generation defaults are derived from the decoder's special tokens.
        self.decoder_start_token_id = self.text_config.bos_token_id
        self.forced_eos_token_id = self.text_config.eos_token_id

    @classmethod
    def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        """Alternate constructor from already-built sub-config objects.

        Args:
            vision_config (`ViTConfig`): configuration of the vision encoder.
            text_config (`GPT2Config`): configuration of the text decoder.
            kwargs: forwarded to `__init__`.

        Returns:
            `ViTGPT2Config`: the composite configuration.
        """
        return cls(vision_config_dict=vision_config.to_dict(), text_config_dict=text_config.to_dict(), **kwargs)

    def to_dict(self):
        """Serialize this instance to a plain `dict`, expanding both sub-configs.

        Returns:
            `dict[str, Any]`: all attributes of this configuration, with the
            sub-configs expanded under "vision_config" / "text_config".
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        # `model_type` is a class attribute, so it is absent from `__dict__`
        # and must be added explicitly.
        output["model_type"] = self.__class__.model_type
        return output
|
|