vit-gpt2 / vit_gpt2 /configuration_vit_gpt2.py
ydshieh
clean up config script
ac161f7
raw
history blame
1.46 kB
import copy
from transformers import ViTConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class ViTGPT2Config(PretrainedConfig):
model_type = "vit-gpt2"
is_composition = True
def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
if vision_config_dict is None:
vision_config_dict = {}
logger.info("vision_config_dict is None. initializing the ViTConfig with default values.")
if text_config_dict is None:
text_config_dict = {}
logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")
self.vision_config = ViTConfig(**vision_config_dict)
self.text_config = GPT2Config(**text_config_dict)
@classmethod
def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
return cls(vision_config_dict=vision_config.to_dict(), text_config_dict=text_config.to_dict(), **kwargs)
def to_dict(self):
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
return output